-
Notifications
You must be signed in to change notification settings - Fork 103
/
Vision.hs
648 lines (617 loc) · 26.4 KB
/
Vision.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Torch.Vision where
import qualified Codec.Picture as I
import Control.Exception.Safe
( SomeException (..),
throwIO,
try,
)
import Control.Monad
( MonadPlus,
forM_,
when,
)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Foreign.Marshal.Utils (copyBytes)
import Data.Int
import Data.Kind (Type)
import qualified Data.Vector.Storable as V
import Data.Word
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import GHC.Exts (IsList (fromList))
import qualified Language.C.Inline as C
import Pipes
import System.IO.Unsafe
import System.Random (mkStdGen, randoms)
import qualified Torch.DType as D
import Torch.Data.Pipeline
import Torch.Data.StreamedPipeline
import Torch.Functional hiding (take)
import qualified Torch.Functional as D
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import Torch.NN
import Torch.Tensor
import qualified Torch.Tensor as D
import qualified Torch.TensorOptions as D
import qualified Torch.Typed.Vision as I
import Prelude hiding (max, min)
import qualified Prelude as P
C.include "<stdint.h>"
-- | MNIST data served in fixed-size minibatches.
--
-- The type parameter @m@ is phantom: it does not appear in any field and is
-- only used to pick the monad in the 'Datastream' and 'Dataset' instances.
data MNIST (m :: Type -> Type) = MNIST
  { -- | Number of samples in every minibatch.
    batchSize :: Int,
    -- | MNIST images and labels the minibatches are read from.
    mnistData :: I.MnistData
  }
-- | Streams the dataset as consecutive, non-overlapping minibatches of
-- @(images, labels)@. Only full batches are produced: samples left over after
-- dividing the dataset length by 'batchSize' are never yielded. The seed
-- argument is ignored.
instance Monad m => Datastream m Int (MNIST m) (Tensor, Tensor) where
  streamSamples MNIST {..} seed =
    Select $ for (each [0 .. batchCount - 1]) yieldBatch
    where
      -- Number of complete batches available in the dataset.
      batchCount = I.length mnistData `Prelude.div` batchSize
      -- Build and emit the batch starting at sample @batchIx * batchSize@.
      yieldBatch batchIx =
        let firstIx = batchIx * batchSize
            sampleIxs = [firstIx .. firstIx + batchSize - 1]
            labels = getLabels' batchSize mnistData sampleIxs
            images = getImages' batchSize 784 mnistData sampleIxs
         in yield (images, labels)
-- | Random-access view of the dataset where key @ix@ denotes the @ix@-th
-- minibatch, i.e. samples @ix * batchSize@ up to @(ix + 1) * batchSize - 1@.
instance Applicative m => Dataset m (MNIST m) Int (Tensor, Tensor) where
  getItem MNIST {..} ix = pure (batchImages, batchLabels)
    where
      firstIx = ix * batchSize
      sampleIxs = [firstIx .. firstIx + batchSize - 1]
      batchImages = getImages' batchSize 784 mnistData sampleIxs
      batchLabels = getLabels' batchSize mnistData sampleIxs

  -- One key per complete batch; a trailing partial batch gets no key.
  keys MNIST {..} = fromList [0 .. batchCount - 1]
    where
      batchCount = I.length mnistData `Prelude.div` batchSize
-- | Collect the labels of at most @n@ samples, taken from the front of
-- @imageIdxs@, into a single tensor.
getLabels' :: Int -> I.MnistData -> [Int] -> Tensor
getLabels' n mnist imageIdxs =
  asTensor [I.getLabel mnist idx | idx <- take n imageIdxs]
-- | Gather a minibatch of images into an @[n, dataDim]@ 'D.Float' tensor.
--
-- Row @i@ of the result holds the @dataDim@ bytes of image @imageIdxs !! i@,
-- read from the raw image buffer after skipping its first 16 bytes
-- (presumably the IDX file header -- confirm against the loader).
--
-- Every index that will be copied is validated first; an index that would
-- read outside the image buffer raises an 'IOError' instead of reading
-- arbitrary memory.
--
-- NOTE: the tensor is allocated uninitialised, so if @imageIdxs@ has fewer
-- than @n@ elements the remaining rows contain unspecified values.
getImages' ::
  Int -> -- number of observations in minibatch
  Int -> -- dimensionality of the data
  I.MnistData -> -- mnist data representation
  [Int] -> -- indices of the dataset
  Tensor
getImages' n dataDim mnist imageIdxs = unsafePerformIO $ do
  let (BSI.PS fptr off len) = I.images mnist
      headerSize = 16
      -- An image is readable only if its last byte lies inside the buffer.
      inBounds idx = idx >= 0 && headerSize + dataDim * (idx + 1) <= len
  forM_ (take n imageIdxs) $ \idx ->
    when (not (inBounds idx)) $
      throwIO $ userError $ "getImages': image index " ++ show idx ++ " is out of range."
  t <-
    (cast2 LibTorch.empty_lo :: [Int] -> D.TensorOptions -> IO D.Tensor)
      [n, dataDim]
      (D.withDType D.UInt8 D.defaultOpts)
  D.withTensor t $ \ptr1 ->
    F.withForeignPtr fptr $ \ptr2 ->
      -- zip truncates to the shorter list, so at most n rows are written.
      forM_ (zip [0 .. (n - 1)] imageIdxs) $ \(i, idx) ->
        copyBytes
          (F.plusPtr ptr1 (dataDim * i))
          (F.plusPtr ptr2 (off + headerSize + dataDim * idx))
          dataDim
  return $ D.toType D.Float t
-- http://paulbourke.net/dataformats/asciiart/

-- | Ten-level ASCII brightness ramp, ordered from the blank character to the
-- densest glyph.
grayScale10 :: String
grayScale10 = " .:-=+*#%@"
-- | Seventy-level ASCII brightness ramp. The literal is written densest glyph
-- first and reversed, so the resulting string starts with the blank
-- character and ends with @\'$\'@.
grayScale70 :: String
grayScale70 = reverse "$@B%8&WM#*oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\|()1{}[]?-_+~<>i!lI;:,\"^`'. "
-- | Display an MNIST image tensor as ASCII text on stdout.
--
-- The tensor is reshaped to 28x28, linearly rescaled so that its minimum maps
-- to the first and its maximum to the last character of the palette, and
-- every second row and column is printed.
--
-- A constant image (maximum equal to minimum) makes the rescaling divide
-- zero by zero; such pixels are rendered with the first palette character
-- instead of crashing on an invalid list index.
dispImage :: Tensor -> IO ()
dispImage img =
  forM_ sampledIxs $ \row -> do
    forM_ sampledIxs $ \col ->
      putChar $ grayScale !! paletteIndex (scaled !! row !! col)
    putStrLn ""
  where
    downSamp = 2
    -- Pixel coordinates that are actually printed.
    sampledIxs = [0, downSamp .. 27]
    grayScale = grayScale10
    lastLevel = length grayScale - 1
    paletteMax = (fromIntegral $ length grayScale) - 1.0
    -- Map a rescaled intensity to a valid palette position; NaN and any
    -- out-of-range value are clamped into [0, lastLevel].
    paletteIndex :: Float -> Int
    paletteIndex v
      | isNaN v = 0
      | otherwise = P.max 0 (P.min lastLevel (P.floor v))
    img' = reshape [28, 28] img
    scaled :: [[Float]]
    scaled =
      let (mn, mx) = (min img', max img')
       in asValue $ (img' - mn) / (mx - mn) * paletteMax
-- | Tag describing the pixel layout of a decoded image.
--
-- There is one constructor per 'I.DynamicImage' variant handled by
-- 'pixelFormat'; the names follow the corresponding @Image*@ constructors.
data PixelFormat
  = Y8
  | YF
  | YA8
  | RGB8
  | RGBF
  | RGBA8
  | YCbCr8
  | CMYK8
  | CMYK16
  | RGBA16
  | RGB16
  | Y16
  | YA16
  | Y32
  deriving (Show, Eq)
-- | Load an image file and return it as a tensor together with the pixel
-- format it was decoded as. A decoding failure is returned as 'Left' with
-- the decoder's message.
readImage :: FilePath -> IO (Either String (D.Tensor, PixelFormat))
readImage file = fmap describe <$> I.readImage file
  where
    describe dyn = (fromDynImage dyn, pixelFormat dyn)
-- | Load an image file, convert it to 8-bit RGB and return it as a tensor.
-- A decoding failure is returned as 'Left' with the decoder's message.
readImageAsRGB8 :: FilePath -> IO (Either String D.Tensor)
readImageAsRGB8 file =
  fmap (fromDynImage . I.ImageRGB8 . I.convertRGB8) <$> I.readImage file
-- | Load an image file, convert it to 8-bit RGB and resize it to
-- @width@ x @height@ with 'resizeRGB8'. Returns both the resized image and
-- its tensor form; a decoding failure is returned as 'Left'.
readImageAsRGB8WithScaling :: FilePath -> Int -> Int -> Bool -> IO (Either String (I.Image I.PixelRGB8, D.Tensor))
readImageAsRGB8WithScaling file width height keepAspectRatio = do
  loaded <- I.readImage file
  pure $ case loaded of
    Left err -> Left err
    Right dyn ->
      let scaled = resizeRGB8 width height keepAspectRatio (I.convertRGB8 dyn)
       in Right (scaled, fromDynImage (I.ImageRGB8 scaled))
-- | Cut a @width@ x @height@ window out of the centre of an RGB8 image.
--
-- The result starts as an all-black image; every destination pixel whose
-- corresponding source pixel exists is overwritten, so parts of the window
-- that fall outside the source stay black.
--
-- The freshly generated result buffer is filled in place from C, which is
-- why the function is wrapped in 'unsafePerformIO'.
centerCrop :: Int -> Int -> I.Image I.PixelRGB8 -> I.Image I.PixelRGB8
centerCrop width height input = unsafePerformIO $ do
  let I.Image org_w org_h org_vec = input
      img@(I.Image w h vec) = I.generateImage (\_ _ -> (I.PixelRGB8 0 0 0)) width height :: I.Image I.PixelRGB8
      (org_fptr, _) = V.unsafeToForeignPtr0 org_vec
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr org_fptr $ \ptr1 -> F.withForeignPtr fptr $ \ptr2 -> do
    let src = F.castPtr ptr1
        dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        iorg_w = fromIntegral org_w
        iorg_h = fromIntegral org_h
        ichannel = 3
    [C.block| void {
        uint8_t* src = $(uint8_t* src);
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int channel = $(int ichannel);
        int ow = $(int iorg_w);
        int oh = $(int iorg_h);
        // Top-left corner of the crop window in source coordinates.
        int offsetx = (ow - w)/2;
        int offsety = (oh - h)/2;
        for(int y=0;y<h;y++){
          int sy = y + offsety;
          if(sy < 0 || sy >= oh) continue;
          for(int x=0;x<w;x++){
            int sx = x + offsetx;
            if(sx < 0 || sx >= ow) continue;
            for(int c=0;c<channel;c++){
              dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
            }
          }
        }
    } |]
  return img
-- | Draw a straight line from @(x0, y0)@ towards @(x1, y1)@ directly into the
-- pixel buffer of an RGB8 image, using colour @(r, g, b)@.
--
-- The line is stepped along its longer axis and the end point itself is not
-- drawn (the loop stops when it reaches it). Pixels outside the image are
-- skipped. The image is modified in place.
drawLine :: Int -> Int -> Int -> Int -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawLine x0 y0 x1 y1 (r, g, b) input = do
  let I.Image w h vec = input
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr fptr $ \ptr2 -> do
    let dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        ix0 = fromIntegral x0
        iy0 = fromIntegral y0
        ix1 = fromIntegral x1
        iy1 = fromIntegral y1
        ir = fromIntegral r
        ig = fromIntegral g
        ib = fromIntegral b
    [C.block| void {
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int x0 = $(int ix0);
        int y0 = $(int iy0);
        int x1 = $(int ix1);
        int y1 = $(int iy1);
        int r = $(int ir);
        int g = $(int ig);
        int b = $(int ib);
        int channel = 3;
        int dx = x1 - x0;
        int dy = y1 - y0;
        int step_x = dx >= 0 ? 1 : -1;
        int step_y = dy >= 0 ? 1 : -1;
        int len_x = dx >= 0 ? dx : -dx;
        int len_y = dy >= 0 ? dy : -dy;
        if(len_x >= len_y){
          // Mostly horizontal: walk x and interpolate y.
          for(int x=x0;x!=x1;x+=step_x){
            int y = (x-x0) * dy / dx + y0;
            if(y < 0 || y >= h || x < 0 || x >= w) continue;
            uint8_t* px = dst + (y*w+x)*channel;
            px[0] = r;
            px[1] = g;
            px[2] = b;
          }
        } else {
          // Mostly vertical: walk y and interpolate x.
          for(int y=y0;y!=y1;y+=step_y){
            int x = (y-y0) * dx / dy + x0;
            if(y < 0 || y >= h || x < 0 || x >= w) continue;
            uint8_t* px = dst + (y*w+x)*channel;
            px[0] = r;
            px[1] = g;
            px[2] = b;
          }
        }
    } |]
-- | Draw the outline of the rectangle with corners @(x0, y0)@ and @(x1, y1)@
-- in colour @(r, g, b)@, modifying the image in place.
--
-- Each edge is passed to 'drawLine' with its far coordinate extended by one,
-- because 'drawLine' does not draw the end point.
drawRect :: Int -> Int -> Int -> Int -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawRect x0 y0 x1 y1 (r, g, b) input =
  mapM_ edge
    [ (x0, y0, x1 + 1, y0), -- top
      (x0, y0, x0, y1 + 1), -- left
      (x0, y1, x1 + 1, y1), -- bottom
      (x1, y0, x1, y1 + 1) -- right
    ]
  where
    edge (ax, ay, bx, by) = drawLine ax ay bx by (r, g, b) input
-- | Render @text@ into the image starting at @(x0, y0)@, one character after
-- another from left to right, advancing 8 pixels per character. @(r, g, b)@
-- is the glyph colour and @(br, bg, bb)@ the background colour; both are
-- forwarded unchanged to 'drawChar'.
drawString :: String -> Int -> Int -> (Int, Int, Int) -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawString text x0 y0 (r, g, b) (br, bg, bb) input =
  mapM_ putGlyph (zip [0 ..] text)
  where
    putGlyph (col, ch) =
      drawChar (fromEnum ch) (x0 + col * 8) y0 (r, g, b) (br, bg, bb) input
-- | Paint one 8x8 character cell whose top-left corner is @(x0, y0)@,
-- modifying the image in place.
--
-- Set bits of the glyph are drawn in @(r, g, b)@, all other pixels of the
-- cell in the background colour @(br, bg, bb)@. Codes outside the range
-- 0x21..0x7e have no glyph, so their whole cell is filled with the
-- background colour. Pixels outside the image are skipped.
drawChar :: Int -> Int -> Int -> (Int, Int, Int) -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawChar ascii_code x0 y0 (r, g, b) (br, bg, bb) input = do
  let I.Image w h vec = input
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr fptr $ \ptr2 -> do
    let dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        ix0 = fromIntegral x0
        iy0 = fromIntegral y0
        ir = fromIntegral r
        ig = fromIntegral g
        ib = fromIntegral b
        ibr = fromIntegral br
        ibg = fromIntegral bg
        ibb = fromIntegral bb
        iascii_code = fromIntegral ascii_code
    [C.block| void {
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int x0 = $(int ix0);
        int y0 = $(int iy0);
        int r = $(int ir);
        int g = $(int ig);
        int b = $(int ib);
        int br = $(int ibr);
        int bg = $(int ibg);
        int bb = $(int ibb);
        int ascii_code = $(int iascii_code);
        int channel = 3;
        int char_width = 8;
        int char_height = 8;
        char fonts[95][8] = { // 0x20 to 0x7e
          { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
          { 0x18, 0x3C, 0x3C, 0x18, 0x18, 0x00, 0x18, 0x00},
          { 0x36, 0x36, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
          { 0x36, 0x36, 0x7F, 0x36, 0x7F, 0x36, 0x36, 0x00},
          { 0x0C, 0x3E, 0x03, 0x1E, 0x30, 0x1F, 0x0C, 0x00},
          { 0x00, 0x63, 0x33, 0x18, 0x0C, 0x66, 0x63, 0x00},
          { 0x1C, 0x36, 0x1C, 0x6E, 0x3B, 0x33, 0x6E, 0x00},
          { 0x06, 0x06, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00},
          { 0x18, 0x0C, 0x06, 0x06, 0x06, 0x0C, 0x18, 0x00},
          { 0x06, 0x0C, 0x18, 0x18, 0x18, 0x0C, 0x06, 0x00},
          { 0x00, 0x66, 0x3C, 0xFF, 0x3C, 0x66, 0x00, 0x00},
          { 0x00, 0x0C, 0x0C, 0x3F, 0x0C, 0x0C, 0x00, 0x00},
          { 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x0C, 0x06},
          { 0x00, 0x00, 0x00, 0x3F, 0x00, 0x00, 0x00, 0x00},
          { 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x0C, 0x00},
          { 0x60, 0x30, 0x18, 0x0C, 0x06, 0x03, 0x01, 0x00},
          { 0x3E, 0x63, 0x73, 0x7B, 0x6F, 0x67, 0x3E, 0x00},
          { 0x0C, 0x0E, 0x0C, 0x0C, 0x0C, 0x0C, 0x3F, 0x00},
          { 0x1E, 0x33, 0x30, 0x1C, 0x06, 0x33, 0x3F, 0x00},
          { 0x1E, 0x33, 0x30, 0x1C, 0x30, 0x33, 0x1E, 0x00},
          { 0x38, 0x3C, 0x36, 0x33, 0x7F, 0x30, 0x78, 0x00},
          { 0x3F, 0x03, 0x1F, 0x30, 0x30, 0x33, 0x1E, 0x00},
          { 0x1C, 0x06, 0x03, 0x1F, 0x33, 0x33, 0x1E, 0x00},
          { 0x3F, 0x33, 0x30, 0x18, 0x0C, 0x0C, 0x0C, 0x00},
          { 0x1E, 0x33, 0x33, 0x1E, 0x33, 0x33, 0x1E, 0x00},
          { 0x1E, 0x33, 0x33, 0x3E, 0x30, 0x18, 0x0E, 0x00},
          { 0x00, 0x0C, 0x0C, 0x00, 0x00, 0x0C, 0x0C, 0x00},
          { 0x00, 0x0C, 0x0C, 0x00, 0x00, 0x0C, 0x0C, 0x06},
          { 0x18, 0x0C, 0x06, 0x03, 0x06, 0x0C, 0x18, 0x00},
          { 0x00, 0x00, 0x3F, 0x00, 0x00, 0x3F, 0x00, 0x00},
          { 0x06, 0x0C, 0x18, 0x30, 0x18, 0x0C, 0x06, 0x00},
          { 0x1E, 0x33, 0x30, 0x18, 0x0C, 0x00, 0x0C, 0x00},
          { 0x3E, 0x63, 0x7B, 0x7B, 0x7B, 0x03, 0x1E, 0x00},
          { 0x0C, 0x1E, 0x33, 0x33, 0x3F, 0x33, 0x33, 0x00},
          { 0x3F, 0x66, 0x66, 0x3E, 0x66, 0x66, 0x3F, 0x00},
          { 0x3C, 0x66, 0x03, 0x03, 0x03, 0x66, 0x3C, 0x00},
          { 0x1F, 0x36, 0x66, 0x66, 0x66, 0x36, 0x1F, 0x00},
          { 0x7F, 0x46, 0x16, 0x1E, 0x16, 0x46, 0x7F, 0x00},
          { 0x7F, 0x46, 0x16, 0x1E, 0x16, 0x06, 0x0F, 0x00},
          { 0x3C, 0x66, 0x03, 0x03, 0x73, 0x66, 0x7C, 0x00},
          { 0x33, 0x33, 0x33, 0x3F, 0x33, 0x33, 0x33, 0x00},
          { 0x1E, 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
          { 0x78, 0x30, 0x30, 0x30, 0x33, 0x33, 0x1E, 0x00},
          { 0x67, 0x66, 0x36, 0x1E, 0x36, 0x66, 0x67, 0x00},
          { 0x0F, 0x06, 0x06, 0x06, 0x46, 0x66, 0x7F, 0x00},
          { 0x63, 0x77, 0x7F, 0x7F, 0x6B, 0x63, 0x63, 0x00},
          { 0x63, 0x67, 0x6F, 0x7B, 0x73, 0x63, 0x63, 0x00},
          { 0x1C, 0x36, 0x63, 0x63, 0x63, 0x36, 0x1C, 0x00},
          { 0x3F, 0x66, 0x66, 0x3E, 0x06, 0x06, 0x0F, 0x00},
          { 0x1E, 0x33, 0x33, 0x33, 0x3B, 0x1E, 0x38, 0x00},
          { 0x3F, 0x66, 0x66, 0x3E, 0x36, 0x66, 0x67, 0x00},
          { 0x1E, 0x33, 0x07, 0x0E, 0x38, 0x33, 0x1E, 0x00},
          { 0x3F, 0x2D, 0x0C, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
          { 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x3F, 0x00},
          { 0x33, 0x33, 0x33, 0x33, 0x33, 0x1E, 0x0C, 0x00},
          { 0x63, 0x63, 0x63, 0x6B, 0x7F, 0x77, 0x63, 0x00},
          { 0x63, 0x63, 0x36, 0x1C, 0x1C, 0x36, 0x63, 0x00},
          { 0x33, 0x33, 0x33, 0x1E, 0x0C, 0x0C, 0x1E, 0x00},
          { 0x7F, 0x63, 0x31, 0x18, 0x4C, 0x66, 0x7F, 0x00},
          { 0x1E, 0x06, 0x06, 0x06, 0x06, 0x06, 0x1E, 0x00},
          { 0x03, 0x06, 0x0C, 0x18, 0x30, 0x60, 0x40, 0x00},
          { 0x1E, 0x18, 0x18, 0x18, 0x18, 0x18, 0x1E, 0x00},
          { 0x08, 0x1C, 0x36, 0x63, 0x00, 0x00, 0x00, 0x00},
          { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF},
          { 0x0C, 0x0C, 0x18, 0x00, 0x00, 0x00, 0x00, 0x00},
          { 0x00, 0x00, 0x1E, 0x30, 0x3E, 0x33, 0x6E, 0x00},
          { 0x07, 0x06, 0x06, 0x3E, 0x66, 0x66, 0x3B, 0x00},
          { 0x00, 0x00, 0x1E, 0x33, 0x03, 0x33, 0x1E, 0x00},
          { 0x38, 0x30, 0x30, 0x3e, 0x33, 0x33, 0x6E, 0x00},
          { 0x00, 0x00, 0x1E, 0x33, 0x3f, 0x03, 0x1E, 0x00},
          { 0x1C, 0x36, 0x06, 0x0f, 0x06, 0x06, 0x0F, 0x00},
          { 0x00, 0x00, 0x6E, 0x33, 0x33, 0x3E, 0x30, 0x1F},
          { 0x07, 0x06, 0x36, 0x6E, 0x66, 0x66, 0x67, 0x00},
          { 0x0C, 0x00, 0x0E, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
          { 0x30, 0x00, 0x30, 0x30, 0x30, 0x33, 0x33, 0x1E},
          { 0x07, 0x06, 0x66, 0x36, 0x1E, 0x36, 0x67, 0x00},
          { 0x0E, 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
          { 0x00, 0x00, 0x33, 0x7F, 0x7F, 0x6B, 0x63, 0x00},
          { 0x00, 0x00, 0x1F, 0x33, 0x33, 0x33, 0x33, 0x00},
          { 0x00, 0x00, 0x1E, 0x33, 0x33, 0x33, 0x1E, 0x00},
          { 0x00, 0x00, 0x3B, 0x66, 0x66, 0x3E, 0x06, 0x0F},
          { 0x00, 0x00, 0x6E, 0x33, 0x33, 0x3E, 0x30, 0x78},
          { 0x00, 0x00, 0x3B, 0x6E, 0x66, 0x06, 0x0F, 0x00},
          { 0x00, 0x00, 0x3E, 0x03, 0x1E, 0x30, 0x1F, 0x00},
          { 0x08, 0x0C, 0x3E, 0x0C, 0x0C, 0x2C, 0x18, 0x00},
          { 0x00, 0x00, 0x33, 0x33, 0x33, 0x33, 0x6E, 0x00},
          { 0x00, 0x00, 0x33, 0x33, 0x33, 0x1E, 0x0C, 0x00},
          { 0x00, 0x00, 0x63, 0x6B, 0x7F, 0x7F, 0x36, 0x00},
          { 0x00, 0x00, 0x63, 0x36, 0x1C, 0x36, 0x63, 0x00},
          { 0x00, 0x00, 0x33, 0x33, 0x33, 0x3E, 0x30, 0x1F},
          { 0x00, 0x00, 0x3F, 0x19, 0x0C, 0x26, 0x3F, 0x00},
          { 0x38, 0x0C, 0x0C, 0x07, 0x0C, 0x0C, 0x38, 0x00},
          { 0x18, 0x18, 0x18, 0x00, 0x18, 0x18, 0x18, 0x00},
          { 0x07, 0x0C, 0x0C, 0x38, 0x0C, 0x0C, 0x07, 0x00},
          { 0x6E, 0x3B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
        };
        // Glyph rows for this code, or null when the code has no glyph.
        const char* glyph =
          (ascii_code > 0x20 && ascii_code < 0x7f) ? fonts[ascii_code-0x20] : 0;
        for(int dy=0;dy<char_height;dy++){
          int y = y0 + dy;
          if(y < 0 || y >= h) continue;
          for(int dx=0;dx<char_width;dx++){
            int x = x0 + dx;
            if(x < 0 || x >= w) continue;
            // Bit dx of row dy decides between glyph and background colour.
            int bit = glyph ? (glyph[dy] & (0x1 << dx)) : 0;
            uint8_t* px = dst + (y*w+x)*channel;
            px[0] = bit ? r : br;
            px[1] = bit ? g : bg;
            px[2] = bit ? b : bb;
          }
        }
    } |]
-- | Resize an RGB8 image to @width@ x @height@ by nearest-neighbour sampling.
--
-- With @keepAspectRatio@ the source is scaled uniformly so that it fits
-- inside the target and is centred; the uncovered border keeps the black
-- fill of the freshly generated result. Without it the source is stretched
-- to cover the whole target.
--
-- A source image with zero width or height yields an all-black result. The
-- scaling arithmetic would otherwise divide by zero or read past the empty
-- source buffer.
--
-- The freshly generated result buffer is filled in place from C, which is
-- why the function is wrapped in 'unsafePerformIO'.
resizeRGB8 :: Int -> Int -> Bool -> I.Image I.PixelRGB8 -> I.Image I.PixelRGB8
resizeRGB8 width height keepAspectRatio input = unsafePerformIO $ do
  let I.Image org_w org_h org_vec = input
      img@(I.Image w h vec) = I.generateImage (\_ _ -> (I.PixelRGB8 0 0 0)) width height :: I.Image I.PixelRGB8
      (org_fptr, _) = V.unsafeToForeignPtr0 org_vec
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr org_fptr $ \ptr1 -> F.withForeignPtr fptr $ \ptr2 -> do
    let src = F.castPtr ptr1
        dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        iorg_w = fromIntegral org_w
        iorg_h = fromIntegral org_h
        ichannel = 3
        ckeepAspectRatio = if keepAspectRatio then 1 else 0
    [C.block| void {
        uint8_t* src = $(uint8_t* src);
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int channel = $(int ichannel);
        int ow = $(int iorg_w);
        int oh = $(int iorg_h);
        int keepAspectRatio = $(int ckeepAspectRatio);
        // An empty source has nothing to sample from; leave dst black.
        if(ow > 0 && oh > 0){
          if(keepAspectRatio){
            // Width the source would get if scaled to the target height.
            int t0w = ow * h / oh;
            if (t0w > w) {
              // Too wide: fit to the target width and centre vertically.
              int offset = (h - (oh * w / ow))/2;
              for(int y=offset;y<h-offset;y++){
                for(int x=0;x<w;x++){
                  for(int c=0;c<channel;c++){
                    int sy = (y-offset) * ow / w;
                    int sx = x * ow / w;
                    if(sy >= 0 && sy < oh){
                      dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
                    }
                  }
                }
              }
            } else {
              // Fit to the target height and centre horizontally.
              int offset = (w - (ow * h / oh))/2;
              for(int y=0;y<h;y++){
                for(int x=offset;x<w-offset;x++){
                  for(int c=0;c<channel;c++){
                    int sy = y * oh / h;
                    int sx = (x-offset) * oh / h;
                    if(sx >= 0 && sx < ow){
                      dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
                    }
                  }
                }
              }
            }
          } else {
            // Stretch: scale each axis independently.
            for(int y=0;y<h;y++){
              for(int x=0;x<w;x++){
                for(int c=0;c<channel;c++){
                  int sy = y * oh / h;
                  int sx = x * ow / w;
                  dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
                }
              }
            }
          }
        }
    } |]
  return img
-- | Report which pixel layout a decoded image uses, as a 'PixelFormat' tag
-- named after the matching 'I.DynamicImage' constructor.
pixelFormat :: I.DynamicImage -> PixelFormat
pixelFormat = \case
  I.ImageY8 _ -> Y8
  I.ImageYF _ -> YF
  I.ImageYA8 _ -> YA8
  I.ImageRGB8 _ -> RGB8
  I.ImageRGBF _ -> RGBF
  I.ImageRGBA8 _ -> RGBA8
  I.ImageYCbCr8 _ -> YCbCr8
  I.ImageCMYK8 _ -> CMYK8
  I.ImageCMYK16 _ -> CMYK16
  I.ImageRGBA16 _ -> RGBA16
  I.ImageRGB16 _ -> RGB16
  I.ImageY16 _ -> Y16
  I.ImageYA16 _ -> YA16
  I.ImageY32 _ -> Y32
-- | Convert a decoded image into a tensor of shape
-- @[1, height, width, channel]@.
--
-- 8-bit and float images are copied byte for byte into a 'D.UInt8' /
-- 'D.Float' tensor. 16-bit components are widened into 'D.Int32' and 32-bit
-- components into 'D.Int64', element by element.
fromDynImage :: I.DynamicImage -> D.Tensor
fromDynImage image = unsafePerformIO $ case image of
  I.ImageY8 (I.Image width height vec) -> rawCopy width height 1 D.UInt8 1 vec
  I.ImageYF (I.Image width height vec) -> rawCopy width height 1 D.Float 4 vec
  I.ImageYA8 (I.Image width height vec) -> rawCopy width height 2 D.UInt8 1 vec
  I.ImageRGB8 (I.Image width height vec) -> rawCopy width height 3 D.UInt8 1 vec
  I.ImageRGBF (I.Image width height vec) -> rawCopy width height 3 D.Float 4 vec
  I.ImageRGBA8 (I.Image width height vec) -> rawCopy width height 4 D.UInt8 1 vec
  I.ImageYCbCr8 (I.Image width height vec) -> rawCopy width height 3 D.UInt8 1 vec
  I.ImageCMYK8 (I.Image width height vec) -> rawCopy width height 4 D.UInt8 1 vec
  I.ImageCMYK16 (I.Image width height vec) -> widenU16 width height 4 D.Int32 vec
  I.ImageRGBA16 (I.Image width height vec) -> widenU16 width height 4 D.Int32 vec
  I.ImageRGB16 (I.Image width height vec) -> widenU16 width height 3 D.Int32 vec
  I.ImageY16 (I.Image width height vec) -> widenU16 width height 1 D.Int32 vec
  I.ImageYA16 (I.Image width height vec) -> widenU16 width height 2 D.Int32 vec
  I.ImageY32 (I.Image width height vec) -> widenU32 width height 1 D.Int64 vec
  where
    -- Allocate an uninitialised [1, height, width, channel] tensor.
    newTensor width height channel dtype =
      (cast2 LibTorch.empty_lo :: [Int] -> D.TensorOptions -> IO D.Tensor)
        [1, height, width, channel]
        (D.withDType dtype D.defaultOpts)

    -- Copy the pixel buffer verbatim; dtype_size is the byte size of one
    -- component.
    rawCopy width height channel dtype dtype_size vec = do
      t <- newTensor width height channel dtype
      D.withTensor t $ \ptr1 -> do
        let (fptr, _) = V.unsafeToForeignPtr0 vec
            byteCount = width * height * channel * dtype_size
        F.withForeignPtr fptr $ \ptr2 ->
          copyBytes (F.castPtr ptr1) (F.castPtr ptr2) byteCount
      return t

    -- Widen every unsigned 16-bit component into a 32-bit tensor element.
    widenU16 width height channel dtype vec = do
      t <- newTensor width height channel dtype
      D.withTensor t $ \ptr1 -> do
        let (fptr, _) = V.unsafeToForeignPtr0 vec
            count = fromIntegral $ width * height * channel
        F.withForeignPtr fptr $ \ptr2 -> do
          let src = F.castPtr ptr2
              dst = F.castPtr ptr1
          [C.block| void {
              uint16_t* src = $(uint16_t* src);
              int32_t* dst = $(int32_t* dst);
              int n = $(int count);
              for(int i=0;i<n;i++){
                dst[i] = src[i];
              }
          } |]
      return t

    -- Widen every unsigned 32-bit component into a 64-bit tensor element.
    widenU32 width height channel dtype vec = do
      t <- newTensor width height channel dtype
      D.withTensor t $ \ptr1 -> do
        let (fptr, _) = V.unsafeToForeignPtr0 vec
            count = fromIntegral $ width * height * channel
        F.withForeignPtr fptr $ \ptr2 -> do
          let src = F.castPtr ptr2
              dst = F.castPtr ptr1
          [C.block| void {
              uint32_t* src = $(uint32_t* src);
              int64_t* dst = $(int64_t* dst);
              int n = $(int count);
              for(int i=0;i<n;i++){
                dst[i] = src[i];
              }
          } |]
      return t
-- | Stack RGB8 images of identical size into one 'D.UInt8' tensor of shape
-- @[count, height, width, 3]@.
--
-- The first image fixes the expected width and height. Throws an 'IOError'
-- when the list is empty, when an image's width or height differs from the
-- first image's, or when an image's pixel buffer does not have
-- @width * height * 3@ elements. The dimension checks run before the buffer
-- length check so that a size mismatch is reported with its specific
-- message.
fromImages :: [I.Image I.PixelRGB8] -> IO D.Tensor
fromImages [] =
  throwIO $ userError "The number of images should be greater than 0."
fromImages imgs@(I.Image width height _ : _) = do
  let num_imgs = length imgs
      channel = 3
      -- Number of bytes one image occupies in the tensor.
      whc = width * height * channel
  t <- ((cast2 LibTorch.empty_lo) :: [Int] -> D.TensorOptions -> IO D.Tensor) [num_imgs, height, width, channel] $ D.withDType D.UInt8 D.defaultOpts
  D.withTensor t $ \ptr1 ->
    forM_ (zip [0 ..] imgs) $ \(idx, I.Image width' height' vec) -> do
      let (fptr, len) = V.unsafeToForeignPtr0 vec
      when (width /= width') $
        throwIO $ userError "image's width is not the same as first image's one"
      when (height /= height') $
        throwIO $ userError "image's height is not the same as first image's one"
      when (len /= whc) $
        throwIO $ userError "vector's length is not the same as tensor' one."
      F.withForeignPtr fptr $ \ptr2 ->
        copyBytes (F.plusPtr (F.castPtr ptr1) (whc * idx)) ptr2 len
  return t
-- | Build a @width@ x @height@ image whose pixel buffer is filled from the
-- tensor's raw memory.
--
-- The image is first generated with every pixel set to @pixel@ (which also
-- fixes the pixel type), then overwritten. Throws an 'IOError' when the
-- image buffer does not have @width * height * channel@ elements.
--
-- NOTE(review): exactly @len@ bytes are copied, where @len@ is the element
-- count of the pixel buffer; that covers the whole buffer only for
-- one-byte components -- confirm before using wider pixel types.
writeImage :: forall p. I.Pixel p => Int -> Int -> Int -> p -> D.Tensor -> IO (I.Image p)
writeImage width height channel pixel tensor = do
  let img@(I.Image _ _ vec) = I.generateImage (\_ _ -> pixel) width height :: I.Image p
      (fptr, len) = V.unsafeToForeignPtr0 vec
      expectedLen = width * height * channel
  D.withTensor tensor $ \ptr1 -> do
    when (len /= expectedLen) $
      throwIO $ userError $ "vector's length(" ++ show len ++ ") is not the same as tensor' one."
    F.withForeignPtr fptr $ \ptr2 ->
      copyBytes (F.castPtr ptr2) (F.castPtr ptr1) len
  return img
-- | Save a 'D.UInt8' tensor as a bitmap file.
--
-- Accepted shapes are @[1, h, w, 1]@, @[1, h, w]@, @[h, w, 1]@ and @[h, w]@
-- (grayscale) and @[1, h, w, 3]@ and @[h, w, 3]@ (RGB). The patterns are
-- tried in the order listed below, so a three-dimensional shape starting
-- with 1 is treated as grayscale. Anything else throws an 'IOError'.
writeBitmap :: FilePath -> D.Tensor -> IO ()
writeBitmap file tensor =
  case (D.shape tensor, D.dtype tensor) of
    ([1, height, width, 1], D.UInt8) -> gray width height
    ([1, height, width], D.UInt8) -> gray width height
    ([1, height, width, 3], D.UInt8) -> rgb width height
    ([height, width, 1], D.UInt8) -> gray width height
    ([height, width], D.UInt8) -> gray width height
    ([height, width, 3], D.UInt8) -> rgb width height
    format -> throwIO $ userError $ "Can not write a image. " ++ show format ++ " is unsupported format."
  where
    gray width height =
      writeImage width height 1 (0 :: I.Pixel8) tensor >>= I.writeBitmap file
    rgb width height =
      writeImage width height 3 (I.PixelRGB8 0 0 0) tensor >>= I.writeBitmap file
-- | Save a 'D.UInt8' tensor as a PNG file.
--
-- Accepted shapes are @[1, h, w, 1]@, @[1, h, w]@, @[h, w, 1]@ and @[h, w]@
-- (grayscale) and @[1, h, w, 3]@ and @[h, w, 3]@ (RGB). The patterns are
-- tried in the order listed below, so a three-dimensional shape starting
-- with 1 is treated as grayscale. Anything else throws an 'IOError'.
writePng :: FilePath -> D.Tensor -> IO ()
writePng file tensor =
  case (D.shape tensor, D.dtype tensor) of
    ([1, height, width, 1], D.UInt8) -> gray width height
    ([1, height, width], D.UInt8) -> gray width height
    ([1, height, width, 3], D.UInt8) -> rgb width height
    ([height, width, 1], D.UInt8) -> gray width height
    ([height, width], D.UInt8) -> gray width height
    ([height, width, 3], D.UInt8) -> rgb width height
    format -> throwIO $ userError $ "Can not write a image. " ++ show format ++ " is unsupported format."
  where
    gray width height =
      writeImage width height 1 (0 :: I.Pixel8) tensor >>= I.writePng file
    rgb width height =
      writeImage width height 3 (I.PixelRGB8 0 0 0) tensor >>= I.writePng file
-- | Reorder the axes of a batched image tensor:
-- @[batch, height, width, channel] -> [batch, channel, height, width]@.
hwc2chw :: D.Tensor -> D.Tensor
hwc2chw tensor = D.permute [0, 3, 1, 2] tensor
-- | Reorder the axes of a batched image tensor:
-- @[batch, channel, height, width] -> [batch, height, width, channel]@.
chw2hwc :: D.Tensor -> D.Tensor
chw2hwc tensor = D.permute [0, 2, 3, 1] tensor
-- | Infinite list of pseudo-random values, each reduced modulo @size@.
-- The generator is always seeded with 123, so every call with the same
-- @size@ returns the same sequence.
randomIndexes :: Int -> [Int]
randomIndexes size = map (`mod` size) (randoms generator)
  where
    generator = mkStdGen 123