In [1]:
from fastai.imports import *
from fastai.vision import *
from fastai.data_block import *
from fastai.basic_train import *
from keras.datasets import mnist
from src.data_loader import Shifted_Data_Loader

import pandas as pd

Using TensorFlow backend.


In [2]:
# path = untar_data(URLs.MNIST_SAMPLE)
# data = ImageDataBunch.from_folder(path,bs=32)

In [3]:
class ArraysImageItemList(ImageList,FloatList):
    def __init__(self, items:Iterator, log:bool=False, **kwargs):
        if isinstance(items, ItemList):
            items = items.items
        super(FloatList,self).__init__(items,**kwargs)
    
    def get(self,i):
        return Tensor(super(FloatList,self).get(i).astype('float32'))

In [4]:
def process_keras_mnist(x):
    x = np.squeeze(x)
    x = x/255.
    x = x.reshape(-1,x.shape[-2],x.shape[-1])
    x = np.stack([x,x,x],1)
    return x

In [5]:
DL = Shifted_Data_Loader('mnist',rotation=None,translation=0.8,flatten=False,seed=7)


input_shape:  (56, 56, 1)
dataset:  mnist
scale:  2
tx_max:  0.8
rot_max:  None
loading mnist...
sx_train:  (60000, 56, 56, 1)
making training data...
making testing data...


In [6]:
x_tr = process_keras_mnist(DL.sx_train)
x_te = process_keras_mnist(DL.sx_test)

In [7]:
x_tr.shape

(60000, 3, 56, 56)

In [8]:
x_il = ArraysImageItemList(x_tr)
x_ils = x_il.split_by_rand_pct()

In [9]:
lls = x_ils.label_from_lists(x_ils.train, x_ils.valid)

In [10]:
data = lls.databunch()

In [11]:
class UpSample(nn.Module):
    def __init__(self,feat_in,feat_out,out_shape=None,scale=2):
        super().__init__()
        self.conv = nn.Conv2d(feat_in,feat_out,kernel_size=(3,3),stride=1,padding=1)
        self.out_shape,self.scale = out_shape,scale
        
    
    def forward(self,x):
        return self.conv(
            nn.functional.interpolate(
                x,size=self.out_shape,scale_factor=self.scale,mode='bilinear',align_corners=True))
    
def get_upSamp(feat_in,feat_out, out_shape=None, scale=2, act='relu'):
    
    upSamp = UpSample(feat_in,feat_out,out_shape=out_shape,scale=scale)
    
    layer = nn.Sequential(upSamp)
    
    if act == 'relu':
        act_f = nn.ReLU(inplace=True)
        bn = nn.BatchNorm2d(feat_out)
        layer.add_module('ReLU',act_f)
        layer.add_module('BN',bn)
    elif act == 'sig':
        act_f = nn.Sigmoid()
        layer.add_module('Sigmoid',act_f)
    return layer

def add_layer(m,feat_in,feat_out,name,out_shape=None,scale=2,act='relu'):
    upSamp = get_upSamp(feat_in,feat_out,out_shape=out_shape,scale=scale,act=act)
    m.add_module(name,upSamp)

In [12]:
class DuplexAE(nn.Module):
    def __init__(self,input_shape,z_dim=5,y_dim=10,layer_szs=[3000,1500]):
        super(DuplexAE, self).__init__()
        self.y_dim=y_dim
        self.z_dim=z_dim
        self.input_shape = input_shape
        self.lat_sz = y_dim+z_dim

        self.fc1 = nn.Linear(np.prod(input_shape),layer_szs[0])
        self.fc2 = nn.Linear(layer_szs[0], layer_szs[1])
        
        jmp = 256
        self.ups1 = UpSample(self.lat_sz,jmp)
        self.bn1 = nn.BatchNorm2d(jmp)
        self.ups2 = UpSample(jmp,int(jmp/2))
        self.bn2 = nn.BatchNorm2d(int(jmp/2))
        self.ups3 = UpSample(int(jmp/2),int(jmp/4),out_shape=(14,14),scale=None)
        self.bn3 = nn.BatchNorm2d(int(jmp/4))
        self.ups4 = UpSample(int(jmp/4),int(jmp/8))
        self.bn4 = nn.BatchNorm2d(int(jmp/8))
        self.ups5 = UpSample(int(jmp/8),3)

        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        lat = F.relu(self.fc2(x))        
#         y_lat,z_lat = torch.split(lat,[self.y_dim,self.z_dim],dim=1)
#         lat_cat = torch.cat([y_lat,z_lat],dim=1)
        x = lat.unsqueeze(2)
        x = x.unsqueeze(3)
        x = self.bn1(F.relu(self.ups1(x)))
        x = self.bn2(F.relu(self.ups2(x)))
        x = self.bn3(self.bn3(self.ups3(x)))
        x = self.bn4(self.ups4(x))
        out = torch.sigmoid(self.ups5(x))

        return out
    
    def name(self):
        return "DuplexAE"

In [13]:
# dl = data.dl()

In [14]:
t,_ = next(iter(data.dl()))
t.size()

torch.Size([64, 3, 56, 56])

In [15]:
# list(get_upSamp(15,32,'CodeOut').modules())

In [16]:
mod = DuplexAE(input_shape=(3,56,56))
# layer = nn.Conv2d(15,32,kernel_size=(3,3), stride=(1,1), padding=(1,1))
# mod.add_module('upsamp1',layer)
# mod.cuda()
# mod(t).size(),t.size()
# add_layer(mod,10+5,256,'CodeOut')

In [17]:
# mod = DuplexAE(y_dim=10,z_dim=5)
learn = Learner(data=data,model=mod,loss_func=F.mse_loss)

In [18]:
learn.model(t).size(),t.size()

(torch.Size([484, 3, 56, 56]), torch.Size([64, 3, 56, 56]))

In [19]:
learn.lr_find(start_lr=1e-3,end_lr=1000)

LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


RuntimeError: The size of tensor a (484) must match the size of tensor b (64) at non-singleton dimension 0

In [None]:
learn.recorder.plot()

In [None]:
learn.fit_one_cycle(10,3e-2)