In [1]:
from torchvision.models import resnet18
from torchvision import transforms
import numpy as np
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import pylab as pl
import time
import pyinn.ncrelu as ncrelu
from IPython import display
torch.cuda.set_device(3)

In [2]:
training_data = torch.load('/data/mini_imagenet/mIN_train.pth')
training_data += torch.load('/data/mini_imagenet/mIN_test.pth')
training_data += torch.load('/data/mini_imagenet/mIN_val.pth')
testing_data = [i[-50:] for i in training_data]
training_data = [i[:-50] for i in training_data]
print len(testing_data[0]),len(training_data[0])

50 550


In [3]:
from gpytorch.utils.fft import fftc,ifftc
from torch.autograd import Variable, Function
from torch.nn.parameter import Parameter
from pyinn import conv2d_depthwise as conv2
from torch.nn.functional import conv2d

class FFT1(Function):
    def __init__(self):
        super(FFT1, self).__init__()
    
    def forward(self, inp):
        return fftc(inp)
    
    def backward(self, grad):
        return ifftc(grad)
    
class iFFT1(Function):
    def __init__(self):
        super(iFFT1, self).__init__()
        
    def forward(self, inp):
        return ifftc(inp)
    
    def backward(self, grad):
        return fftc(grad)

class Conv3dF(nn.Module):
    def __init__(self,nchannels):
        super(Conv3dF,self).__init__()
        self.nchannels = nchannels/2+1
        self.striper = Variable(torch.bernoulli(torch.ones(1,nchannels,1,1)/2)*2-1, requires_grad=False).cuda()
                        #Conv2dD(channels = nchannels, kernel_size=3, padding=1)
        self.fft = FFT1()
        self.ifft = iFFT1()
        self.wre = Parameter(torch.randn(self.nchannels,1,3,3)/1000)
        self.wim = Parameter(torch.randn(self.nchannels,1,3,3)/1000)
                                                # TODO proper weight initialization
        
    def forward(self, inp):
        fi = self.fft(self.striper*inp)
        result = Variable(torch.zeros(*fi.size()), requires_grad=True).cuda()
        result[:,:,:,:,0].add(conv2d(fi[:,:,:,:,0], self.wre, padding=1, groups=self.nchannels))
        result[:,:,:,:,0].sub(conv2d(fi[:,:,:,:,1], self.wim, padding=1, groups=self.nchannels))
        result[:,:,:,:,1].add(conv2d(fi[:,:,:,:,0], self.wim, padding=1, groups=self.nchannels))
        result[:,:,:,:,1].add(conv2d(fi[:,:,:,:,1], self.wre, padding=1, groups=self.nchannels))

        return self.ifft(result)

In [4]:
class Block(nn.Module):
    def __init__(self, insize, outsize):
        super(Block, self).__init__()
        self.layers = nn.Sequential(
            Conv3dF(insize),
            nn.BatchNorm2d(outsize)
        )
        
    def forward(self, inp):
        return ncrelu(self.layers(inp))

class ENCODER(nn.Module):
    def __init__(self):
        super(ENCODER, self).__init__()
        self.process = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,padding=1),
            Block(64,64),
            nn.MaxPool2d(2),
            Block(128,128),
            nn.MaxPool2d(2),
            Block(256,256),
            nn.MaxPool2d(2),
            Block(512,512),
            nn.MaxPool2d(2),
            Block(1024,1024),
            nn.AvgPool2d((5,5)))
        self.final = nn.Sequential(
            nn.Linear(1024,100)
        )
            
    def forward(self, inp):
        out = self.process(inp)
        return self.final(out.view(-1, 1024))


In [12]:
model = ENCODER()
model.cuda()
nweights = sum([i.numel() for i in list(model.parameters())])
print(nweights," parameters in neural net.")

(114924, ' parameters in neural net.')


In [6]:
standardize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.485,.456,.406],std=[.229,.224,.225])
            ])

alter = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(84,padding=10),
            standardize
            ])

def batchmaker(theset, way=20, shot=1, alterful=False):
    classes = np.random.choice(len(theset),way)
    if alterful:
        li = [torch.cat([alter(theset[cl][i]).view(1,3,84,84) for i in 
                         np.random.choice(len(theset[0]),shot)],dim=0).float()
              for cl in classes]
    else:
        li = [torch.cat([standardize(theset[cl][i]).view(1,3,84,84) for i in 
                         np.random.choice(len(theset[0]),shot)],dim=0).float()
              for cl in classes]
    support = torch.cat([t[:shot,:,:,:] for t in li],dim=0)
    stargs = torch.LongTensor(np.repeat(classes, shot))
#     query = torch.cat([t[trainshot:,:,:,:] for t in li],dim=0)
#     qtargs = torch.LongTensor([i//shot for i in range(testshot*way)])
    return(Variable(support, volatile=(not alterful)).cuda(),
           Variable(stargs, volatile=(not alterful)).cuda()
          )
# batchmaker(training_data,shot=2)

In [7]:
criterion = nn.CrossEntropyLoss().cuda()
# optimizer = torch.optim.Adam(model.parameters())

In [8]:
def evaluate(model, criterion, testing_data, shot=10, way=100):
    support, targs = batchmaker(testing_data, way=way, shot=shot)
    support = support.view(50, 20, 3, 84, 84)
    targs = targs.view(50, 20)
    acc = 0
    loss = 0
    for i in range(50):
        preds = model(support[i,:,:,:,:])
        loss += criterion(preds, targs[i,:]).data[0]
        _,bins = torch.max(preds, 1)
        acc += torch.sum(torch.eq(bins,targs[i,:])).data[0]
    acc = float(acc)/shot/way
    return loss, acc

In [13]:
%matplotlib inline

vbity = 200
epoch = 2000
start = time.time()
losstracker = []
evalacctracker = []
evallosstracker = []
runningloss = 0
for it in range(10*epoch):
    if it%50==0:
        print(it)
    
    # Build batch
    support, targs = batchmaker(training_data, alterful=True)
    
    # Predict
    model.zero_grad()
    preds = model(support)

    # Calculate Loss
    loss = criterion(preds, targs)
    runningloss += loss.data[0]
    
    # Backprop
    if it%epoch == 0:
        optimizer = torch.optim.Adam(model.parameters(),lr=.01/(2**(it//epoch)))
    loss.backward()
#     nn.utils.clip_grad_norm(model.parameters(), 1)
    optimizer.step()
    
    # Report
    if it%vbity == vbity-1:
        display.clear_output(wait=True)

        losstracker.append(runningloss/vbity)
        model = model.eval()
        evalloss, evalacc = evaluate(model, criterion, testing_data)
        print evalloss, evalacc
        model = model.train()
        evallosstracker.append(evalloss)
        evalacctracker.append(evalacc)
        
        pl.figure(1,figsize=(15,5))
        pl.subplot(1,2,1)
        pl.plot(losstracker)
        pl.plot(evallosstracker)
#         pl.ylim((.5,3))
        pl.title("Loss: Training Blue, Validation Gold")
        pl.subplot(1,2,2)
        pl.plot(evalacctracker[::-1])
#         pl.ylim((0.3,.8))
        pl.title("Validation Acc")
        pl.show()
        
        print("Train loss is: "+str(runningloss/vbity)+
              "\nValidation accuracy is: "+str(evalacc)+
              "\nValidation loss is: "+str(evalloss)+"\n")
        runningloss = 0
        print(time.time()-start)

KeyboardInterrupt: 