In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

In [None]:
# ToTensor is the only transform applied; NetworkModel's first convolution
# takes the resulting single-channel images directly.
transform = transforms.Compose([transforms.ToTensor()])

# Training split (train=True is spelled out to contrast with the test split).
trainset = MNIST('data/mnist', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

# Held-out split: train=False selects the MNIST test images, and the loader
# must wrap `testset`. Wrapping `trainset` here would make run_eval report
# training accuracy. Shuffling is unnecessary for evaluation.
testset = MNIST('data/mnist', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

In [6]:
# Synthetic stand-in batch: 128 random 3x32x32 inputs with values in [0, 1)
# and 128 integer class labels drawn from {0, ..., 9}.
X = torch.rand((128, 3, 32, 32))
y = torch.randint(low=0, high=10, size=(128,))

In [7]:
class Head(nn.Module):
    """Linear classification head that maps flattened features to class logits.

    Args:
        in_features: Length of each flattened feature vector. The default of
            1152 is the value that used to be hard-coded (128 channels x 3 x 3,
            which is what ChainModel yields for 32x32 inputs).
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_features=1152, num_classes=10):
        super().__init__()
        # The attribute stays named `nn` so existing state_dict keys
        # ("nn.weight", "nn.bias") remain valid.
        self.nn = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, in_features)."""
        return self.nn(x)
def flat(x):
    """Collapse every dimension after the first, giving a (x.shape[0], -1) tensor."""
    batch_size = x.shape[0]
    return x.reshape(batch_size, -1)
class ChainModel(nn.Module):
    """Plain three-stage convolutional feature extractor.

    Each stage is a 3x3 convolution followed by ReLU and a 2x2 max-pool with
    stride 2. The first convolution is unpadded; the other two use padding=1.
    Channel widths go 3 -> 32 -> 64 -> 128.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the pooled 128-channel feature map for a 3-channel input batch."""
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        # The second stage pools before the ReLU, unlike the other two stages.
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), 2, 2))
        stage3 = F.max_pool2d(F.relu(self.conv3(stage2)), 2, 2)
        return stage3
class Inception(nn.Module):
    """Bottleneck branch: 1x1 reduce to D//4 channels, `num` 3x3 convs, 1x1 expand to D.

    Args:
        D: Number of input and output channels.
        num: How many 3x3 (padding=1) convolutions are stacked in the middle.

    The middle convolutions are chained in an nn.Sequential with no
    activation between them; ReLU is applied once after the whole stack.
    """

    def __init__(self, D, num=1):
        super().__init__()
        width = D // 4
        self.cin = nn.Conv2d(D, width, kernel_size=1)
        stack = []
        for _ in range(num):
            stack.append(nn.Conv2d(width, width, kernel_size=3, padding=1))
        self.middle = nn.Sequential(*stack)
        self.cout = nn.Conv2d(width, D, kernel_size=1)

    def forward(self, x):
        """Apply reduce -> middle stack -> expand, with a ReLU after each of the three."""
        reduced = F.relu(self.cin(x))
        mixed = F.relu(self.middle(reduced))
        return F.relu(self.cout(mixed))
class InceptionBlock(nn.Module):
    """Two parallel Inception branches whose outputs are concatenated on dim 1.

    Args:
        D: Channel count of the input. Each branch returns D channels, so the
            concatenated output has 2 * D channels.
    """

    def __init__(self, D):
        super().__init__()
        self.blk1 = Inception(D, 1)  # branch with one 3x3 conv in the middle
        self.blk2 = Inception(D, 2)  # branch with two 3x3 convs in the middle

    def forward(self, x):
        """Feed x to both branches and join the results along dimension 1."""
        outputs = [self.blk1(x), self.blk2(x)]
        return torch.cat(outputs, dim=1)
class NetworkModel(nn.Module):
    """Feature extractor: one unpadded 3x3 conv followed by two InceptionBlocks.

    Every stage is followed by a 2x2 max-pool with stride 2. The stem
    convolution takes 1 input channel and has no activation of its own. Each
    InceptionBlock doubles the channel count (32 -> 64 -> 128).
    """

    def __init__(self):
        super().__init__()
        self.nn1 = nn.Conv2d(1, 32, kernel_size=3)
        self.nn2 = InceptionBlock(32)
        self.nn3 = InceptionBlock(64)

    def forward(self, x):
        """Run the three stages in order, max-pooling after each one."""
        for stage in (self.nn1, self.nn2, self.nn3):
            x = F.max_pool2d(stage(x), 2, 2)
        return x
def generator(*args):
    """Lazily yield every parameter of each given module, in argument order.

    Each argument only needs a `parameters()` method returning an iterable.
    """
    for module in args:
        yield from module.parameters()

def run_graph(model, X, y=None):
    """Run `model` on X, flatten the result and apply the shared head.

    Relies on the notebook-level globals `proj` (the head) and `losses`
    (the loss function), which are looked up at call time.

    Args:
        model: Feature extractor called as model(X).
        X: Input batch passed to the model.
        y: Optional targets. When given, the loss is returned instead of logits.

    Returns:
        losses(logit, y) if y is provided, otherwise the logits from `proj`.
    """
    features = flat(model(X))
    logit = proj(features)
    if y is None:
        return logit
    return losses(logit, y)
def run_eval(model):
    """Compute classification accuracy of `model` (plus the shared head) on `testloader`.

    Reads the notebook-level `testloader` and delegates the forward pass to
    `run_graph`.

    Args:
        model: Feature extractor to evaluate.

    Returns:
        Fraction of correctly classified samples as a Python float, or NaN
        when the loader yields no samples.
    """
    # Follow the model's own device instead of hard-coding 'cuda': the batch
    # must live where the weights live, and a GPU may not be present at all.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for X, y in testloader:
            logit = run_graph(model, X.to(device))
            yhat = torch.argmax(logit, dim=1).cpu()
            # .item() keeps the running count a plain int.
            correct += (yhat == y).sum().item()
            total += y.shape[0]

    if total == 0:
        # Accuracy is undefined without samples; avoid ZeroDivisionError.
        return float('nan')
    return correct / total

In [8]:
# Shared classification head; run_graph reads it through the global name `proj`.
proj=Head()
# The two feature extractors under comparison.
m1=ChainModel()
m2=NetworkModel()
# Loss applied by run_graph (through the global name `losses`) when labels are passed.
losses=nn.CrossEntropyLoss()
# optimizer_m1 updates the parameters of both m1 and the shared head.
optimizer_m1 = optim.Adam(generator(m1,proj), lr=1e-3)
# optimizer_m2 only receives m2's parameters, so stepping it leaves `proj` unchanged.
# NOTE(review): confirm that leaving the head out of this optimizer is intentional.
optimizer_m2 = optim.Adam(generator(m2), lr=1e-3)

In [9]:
# Take 300 optimizer steps for m1 and the shared head (via optimizer_m1), using
# the same fixed tensors (X, y) on every iteration; no data loader is involved.
for i in range(300):
    optimizer_m1.zero_grad()
    loss=run_graph(m1,X,y)
    loss.backward()
    optimizer_m1.step()
    # Prints the loss tensor itself (value plus grad_fn) once per step.
    print(loss)

tensor(2.3071, grad_fn=<NllLossBackward>)
tensor(2.2894, grad_fn=<NllLossBackward>)
tensor(2.2813, grad_fn=<NllLossBackward>)
tensor(2.2789, grad_fn=<NllLossBackward>)
tensor(2.2739, grad_fn=<NllLossBackward>)
tensor(2.2693, grad_fn=<NllLossBackward>)
tensor(2.2650, grad_fn=<NllLossBackward>)
tensor(2.2597, grad_fn=<NllLossBackward>)
tensor(2.2542, grad_fn=<NllLossBackward>)
tensor(2.2485, grad_fn=<NllLossBackward>)
tensor(2.2426, grad_fn=<NllLossBackward>)
tensor(2.2361, grad_fn=<NllLossBackward>)
tensor(2.2288, grad_fn=<NllLossBackward>)
tensor(2.2209, grad_fn=<NllLossBackward>)
tensor(2.2122, grad_fn=<NllLossBackward>)
tensor(2.2027, grad_fn=<NllLossBackward>)


KeyboardInterrupt: 