In [1]:
from torchsummary import summary
import torch
from torch import nn
import imageio
import torch
import glob
from fastai.vision import *
import os
from torch import nn
import torch.nn.functional as F

In [4]:
# Root of the project checkout. Defaults to the original author's location; set the
# ASSET_CLASSIFICATION_ROOT environment variable to run the notebook on another machine.
# os.path.join(root, '') guarantees the trailing separator that the later
# `image_loc += ...` string concatenation relies on.
image_loc = os.path.join(
    os.environ.get(
        'ASSET_CLASSIFICATION_ROOT',
        '/Users/henriwoodcock/Documents/Code/data_projects/automatic-asset-classification/',
    ),
    '',
)

In [5]:
# Point image_loc at the final dataset folder. Guarded so that re-running this cell
# does not append the sub-directory a second time.
if not image_loc.endswith("data/final_dataset/final/"):
    image_loc += "data/final_dataset/final/"

In [28]:
# Random stand-in for the training inputs: 384 samples, each of shape (64, 124, 124).
# No seed is set, so the values differ on every run.
tens = torch.randn((384, 64,124,124))

In [29]:
tens.shape

torch.Size([384, 64, 124, 124])

In [26]:
class encodedFloatItem(ItemBase):
    """Item holding an encoded array as float32.

    `data` is a float32 NumPy conversion of the wrapped value and `obj` is the
    value exactly as it was passed in. The string form is the shape of `data`,
    and the hash is derived from that string.
    """

    def __init__(self, obj):
        as_float32 = np.array(obj).astype(np.float32)
        self.data = as_float32
        self.obj = obj

    def __str__(self):
        return str(self.data.shape)

    def __hash__(self):
        # Hash the printable form (the shape), so items of equal shape hash alike.
        return hash(str(self))



In [27]:
class encodedList(FloatList):
    """`FloatList` subclass whose items are returned as `encodedFloatItem`s.

    Both methods call `super(FloatList, self)`, i.e. they skip `FloatList`'s own
    implementation and use the next class in the MRO. The `log` argument is
    accepted for signature compatibility but is never used here.
    """

    def __init__(self, items:Iterator, log:bool=False, **kwargs):
        # Unwrap an existing ItemList so only its raw items are forwarded.
        raw_items = items.items if isinstance(items, ItemList) else items
        super(FloatList, self).__init__(raw_items, **kwargs)

    def get(self, i):
        # Fetch the raw entry (bypassing FloatList.get), cast it to float32 and
        # wrap it as a tensor before building the item.
        raw = super(FloatList, self).get(i)
        as_tensor = torch.Tensor(raw.astype('float32'))
        return encodedFloatItem(as_tensor)

In [30]:
# Wrap the tensor in an encodedList and keep every item in the training split
# (split_none leaves the validation split empty).
dat = encodedList(tens).split_none()

In [31]:
dat

ItemLists;

Train: encodedList (384 items)
(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124)
Path: .;

Valid: encodedList (0 items)

Path: .;

Test: None

In [33]:
# Second random tensor: 96 samples with the same per-sample shape (64, 124, 124).
tens2 = torch.randn((96, 64,124,124))

In [34]:
# Replace the validation split of `dat` in place with a list built from tens2.
dat.valid = encodedList(tens2)

In [35]:
# Label each split with itself, so every item's target is the item (autoencoder setup).
# NOTE: this rebinds `dat` to the labelled result, so the cell cannot simply be re-run
# without first re-running the cells that build the unlabelled `dat`.
dat = dat.label_from_lists(dat.train,dat.valid)

In [36]:
dat

LabelLists;

Train: LabelList (384 items)
x: encodedList
(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124)
y: ItemList
(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124)
Path: .;

Valid: LabelList (96 items)
x: encodedList
(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124)
y: ItemList
(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124),(64, 124, 124)
Path: .;

Test: None

In [39]:
class AutoEncoder(nn.Module):
    """Single-stage convolutional autoencoder over 64-channel feature maps.

    Encoder: 7x7 convolution with stride 2 (64 -> 112 channels), batch norm, ReLU.
    For an even spatial size H the encoded size is H/2 (e.g. 124 -> 62).
    Decoder: bilinear x2 upsampling, 3x3 convolution (112 -> 64 channels), ReLU.
    Reconstructions are clamped to the range [0, 1].

    Note: for odd spatial sizes the reconstruction is one pixel larger than the
    input, because the encoder rounds up and the decoder doubles.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(64, 112, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            nn.BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            # align_corners=False is what PyTorch already uses when the argument is
            # omitted; passing it explicitly keeps the result identical and avoids the
            # "See the documentation of nn.Upsample" warning on versions that emit it.
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(112, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.ReLU(inplace=True)
        )

    def encode(self, x):
        """Return the encoded representation of `x`."""
        return self.encoder(x)

    def decode(self, x):
        """Reconstruct from an encoded tensor and clamp the result to [0, 1]."""
        return torch.clamp(self.decoder(x), min=0, max=1)

    def forward(self, x):
        """Encode then decode `x`; equivalent to `decode(encode(x))`."""
        # Reuse encode/decode so the clamping logic lives in exactly one place.
        return self.decode(self.encode(x))

In [40]:
# Instantiate the autoencoder; this same instance is reused by the learners created later.
ae = AutoEncoder()

In [50]:
# Set the data object's path to the dataset directory (in-place mutation of `dat`).
dat.path = image_loc

In [55]:
# Build the data bunch with no explicit arguments. `dat` is rebound to the result,
# so this cell is not re-runnable on its own.
dat = dat.databunch()

In [56]:
# Plain fastai Learner using mean-squared error between the model output and the label.
learn = Learner(dat, ae, loss_func=F.mse_loss)

In [57]:
# Run fit_one_cycle with a length of 1. The recorded run was interrupted manually
# (KeyboardInterrupt in the output), so no loss values were captured.
learn.fit_one_cycle(1)

epoch,train_loss,valid_loss,time


  "See the documentation of nn.Upsample for details.".format(mode))


KeyboardInterrupt: 

## Rewriting the Learner

In [59]:
learn.summary()

AutoEncoder
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [112, 62, 62]        351,232    True      
______________________________________________________________________
BatchNorm2d          [112, 62, 62]        224        True      
______________________________________________________________________
ReLU                 [112, 62, 62]        0          False     
______________________________________________________________________
Upsample             [112, 124, 124]      0          False     
______________________________________________________________________
Conv2d               [64, 124, 124]       64,512     True      
______________________________________________________________________
ReLU                 [64, 124, 124]       0          False     
______________________________________________________________________

Total params: 415,968
Total trainable params: 415,968
Total non-trainable params: 0
Optimized with 'torch.optim.a

In [63]:
from fastai.basic_train import Learner
from fastai.basic_data import DataBunch, DatasetType
from fastai.basic_train import get_preds
from fastai.callback import CallbackHandler

from fastai.basic_train import Learner
from fastai.callbacks.hooks import HookCallback

In [64]:
class ReplaceTargetCallback(LearnerCallback):
    """Swap the batch target for the batch input so the loss is computed against x."""
    _order = 9999

    def __init__(self,learn:Learner):
        super().__init__(learn)

    def on_batch_begin(self,last_input,last_target,train,**kwargs):
        # The input is always passed through unchanged. The target becomes the input
        # (reconstruction loss) unless the learner is flagged as inferring, in which
        # case the original target is kept.
        target = last_target if self.learn.inferring else last_input
        return {"last_input" : last_input, "last_target" : target}

In [83]:
class AutoEncoderLearner(Learner):
    """`Learner` that trains `model` to reconstruct its own input.

    A `ReplaceTargetCallback` is appended to the callbacks so that, while
    `self.inferring` is False, each batch's target is replaced by its input.

    Parameters:
        data: the DataBunch to train on.
        rec_loss: "mse" (element-wise `nn.MSELoss`) or "ce" (`nn.CrossEntropyLoss`);
            both are built with `reduction="none"` so per-sample errors can be computed.
        model: the autoencoder module.
        **kwargs: forwarded to `Learner.__init__`.

    Raises:
        AssertionError: if `rec_loss` is neither "mse" nor "ce".
    """

    def __init__(self,data:DataBunch,rec_loss:str,model:nn.Module,**kwargs):
        self.model = model

        assert rec_loss in ["mse","ce"],"Loss function must be mse or ce"
        if rec_loss == "mse":
            self.rec_loss = nn.MSELoss(reduction="none")
        else:
            self.rec_loss = nn.CrossEntropyLoss(reduction="none")

        # Read by ReplaceTargetCallback to decide whether to swap the target.
        self.inferring = False

        super().__init__(data, model, loss_func=self.loss_func, **kwargs)

        # Callback to replace y with x during the training loop
        self.callbacks.append(ReplaceTargetCallback(self))

    def _per_sample_loss(self,x_rec,x):
        "Reconstruction error summed over all non-batch dimensions; one value per sample."
        bs = x.shape[0]
        if isinstance(self.rec_loss,nn.MSELoss):
            elementwise = self.rec_loss(x, x_rec)
        else:
            # First we discretize x and turn it from (B,1,H,W) to (B,H,W)
            # NOTE(review): x == 1.0 maps to class index 256 — confirm the model
            # emits enough classes (or that x is strictly below 1).
            x = (x * 256).long().squeeze(1)
            elementwise = self.rec_loss(x_rec, x)
        return elementwise.view(bs, -1).sum(dim=-1)

    def loss_func(self,x_rec,x,**kwargs):
        "Training loss: mean over the batch of the per-sample reconstruction error."
        return self._per_sample_loss(x_rec, x).mean()

    def get_error(self, ds_type:DatasetType=DatasetType.Valid,activ:nn.Module=None, n_batch=None, pbar=None):
        "Return the per-sample reconstruction error on the `ds_type` dataset."
        x_rec,x = get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
                         activ=activ, loss_func=None, n_batch=n_batch, pbar=pbar)
        # self.rec_loss was already built with reduction="none"; a loss module's forward
        # does not accept a `reduction` keyword, so it must not be passed again here.
        # Sharing _per_sample_loss also applies the "ce" target discretisation.
        return self._per_sample_loss(x_rec, x)

In [90]:
tens2

tensor([[[[ 1.6679e+00,  1.9228e-01, -4.6040e-01,  ...,  7.6809e-01,
            2.0479e+00, -7.6677e-01],
          [-9.8575e-01,  3.0159e-01,  9.7556e-01,  ..., -5.7564e-01,
            4.3427e-03,  5.7943e-01],
          [ 1.9088e-01,  7.3313e-01,  1.4619e-01,  ..., -1.8146e+00,
           -1.3616e+00, -2.1465e-01],
          ...,
          [-1.0969e+00, -5.6939e-01, -3.4922e-01,  ..., -1.0849e+00,
           -1.5564e+00, -7.6119e-01],
          [ 1.6886e-02, -1.1842e+00, -1.5916e+00,  ...,  1.0494e+00,
           -3.1976e-01,  2.5257e-01],
          [-1.2096e+00, -2.1255e-01, -5.6928e-01,  ..., -1.8409e-01,
           -4.5702e-01, -1.6194e+00]],

         [[-3.2417e-01,  5.5885e-01, -1.5489e+00,  ..., -4.4521e-02,
           -8.5023e-01, -1.9721e+00],
          [-3.3657e-01,  5.3394e-01, -2.2566e-01,  ...,  1.3484e+00,
           -1.1849e+00,  5.6131e-01],
          [ 1.7153e+00,  2.4593e-01, -8.7616e-01,  ...,  2.0835e-02,
            3.6075e-01, -2.2893e+00],
          ...,
     

In [None]:
# Rebuild the data from the first tensor, giving every item the constant label 0,
# then reuse the training split as the validation split.
dat = encodedList(tens).split_none().label_const(0)
dat.valid = dat.train

In [None]:
# Build the data bunch with no explicit arguments; `dat` is rebound to the result.
dat = dat.databunch()

In [84]:
# Train the existing `ae` instance through the custom learner with the MSE reconstruction loss.
learn = AutoEncoderLearner(data = dat, rec_loss = 'mse', model = ae)

In [85]:
# Run fit_one_cycle with a length of 1. The recorded run failed while collating a batch
# in a DataLoader worker (default_collate found an object-dtype array).
learn.fit_one_cycle(1)

epoch,train_loss,valid_loss,time


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/fastai/torch_core.py", line 127, in data_collate
    return torch.utils.data.dataloader.default_collate(to_data(batch))
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 62, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object


In [89]:
# Inspect the training DataLoader. The cell previously ended in a dangling '.',
# which is a syntax error.
learn.data.train_dl

AttributeError: 'ItemList' object has no attribute 'data'

In [20]:
# Image-to-image data bunch: every image is labelled with its own path, so the
# target of each item is the image itself.
size = 224
batchsize = 32
# NOTE(review): tfms is created but never passed to .transform() below — confirm intent.
tfms = get_transforms(do_flip = False)

src = (
    ImageImageList.from_folder(image_loc)
    .split_by_rand_pct(seed=2)
    .label_from_func(lambda x: x)
)

# A `do_y = False` argument to normalize() was tried earlier and left disabled.
data = (
    src.transform(size=size, tfm_y=True)
    .databunch(bs=batchsize)
    .normalize(imagenet_stats)
)

In [21]:
data

ImageDataBunch;

Train: LabelList (444 items)
x: ImageImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
Path: /Users/henriwoodcock/Documents/Code/data_projects/automatic-asset-classification/data/final_dataset/final;

Valid: LabelList (110 items)
x: ImageImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
Path: /Users/henriwoodcock/Documents/Code/data_projects/automatic-asset-classification/data/final_dataset/final;

Test: None