In [307]:
import torch
from torch import nn
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.nn.functional import one_hot
from tqdm import tqdm
import torch.nn.functional as F

In [308]:
class model_conv(nn.Module):
  def __init__(self, input_shape = 32, number_classes = 10, train_switch = "train"):
    super().__init__()
    self.input_shape = input_shape
    self.train_switch = train_switch
    self.num_classes = number_classes

    #(batch, 1 , 32, 32)
    self.conv1 = nn.Conv2d(1, 32, kernel_size = 3, padding=1, bias=False)
    self.relu1 = nn.ReLU()
    self.max_pool1 = nn.MaxPool2d(2)
    #(batch, 32, 16, 16)

    self.conv2 = nn.Conv2d(32, 64, kernel_size = 3, padding=1, bias=False)
    self.relu2 = nn.ReLU()
    self.max_pool2 = nn.MaxPool2d(2)
    #(batch, 64, 8, 8)

    shape = ((self.input_shape/4)**2) * 64

    self.dropout = nn.Dropout()
    self.linear = nn.Linear(4096, out_features = self.num_classes)
    self.softmax = nn.Softmax(dim = 1)

  def forward(self, x_input):
    x1 = self.conv1(x_input)
    x1 = self.relu1(x1)
    x1 = self.max_pool1(x1)

    x2 = self.conv2(x1)
    x2 = self.relu2(x2)
    x2 = self.max_pool2(x2)
    
    x_flat = x2.view(x2.size(0), -1)

    x_linear = self.dropout(x_flat)
    x_linear = self.linear(x_linear)
    x_output = self.softmax(x_linear)

    if self.train_switch == "train":
      return x_output
    elif self.train_switch == "FID":
      return x_flat


In [309]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(500, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        x1 = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x1)), 2))
        x_flatten = x1.view(-1, 500)
        x2 = F.relu(self.fc1(x_flatten))
        x3 = F.dropout(x2, training=self.training)
        x_out = self.fc2(x3)
        return F.log_softmax(x_out), x2
        

In [310]:
def train_model(model, optmizer, loss, dataloader, number_epochs = 10, device = "cuda"):
  for epoch in range(number_epochs):
    losses = []
    print("Epoch: ", epoch)
    for i, data in tqdm(enumerate(dataloader)):
      optmizer.zero_grad()

      x0 = data[0].to(device)
      label = data[1].to(device)

      x_pred,_ = model(x0)
  
      loss_calc = F.nll_loss(x_pred, label)
      losses.append(loss_calc.item())
      loss_calc.backward()
      optmizer.step()
    epoch += 1
    print(np.mean(losses))
    #Saving Checkpoint
    
    EPOCH = epoch
    PATH = "/content/drive/MyDrive/Fifth year/ClearBox/Diffusion_model_training/MNIST_model/saved_FID_model/"+ "FID.pt"       
    torch.save({
        'epoch': EPOCH,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optmizer.state_dict()
        }, PATH)
    print("checkpoint saved")


        



In [311]:
def main():
  transforms = torchvision.transforms.Compose([                                           
      torchvision.transforms.Resize(80),         
      torchvision.transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
      torchvision.transforms.ToTensor(), 
      torchvision.transforms.Normalize(0.5, 0.5)])
                                                        
  dataset_train = MNIST("/content/MNIST_train", download=True, train=True,transform=transforms)
  batch_size = 64
  number_epochs = 10

  dataloader_train = DataLoader(dataset_train, batch_size, drop_last=True)
  size_iterations = len(dataloader_train.dataset)/batch_size
  device = "cuda"

  loss = nn.CrossEntropyLoss()

  model = Net().to(device)

  optmizer = torch.optim.Adam(model.parameters())
  size_iterations = len(dataloader_train.dataset)/batch_size
  print(size_iterations)
  train_model(model, optmizer, loss, dataloader_train)

 




In [312]:
if __name__ == "__main__":
  main()

937.5
Epoch:  0


937it [00:28, 32.61it/s]


0.6249547843188207
checkpoint saved
Epoch:  1


937it [00:28, 32.92it/s]


0.32555840894770727
checkpoint saved
Epoch:  2


937it [00:28, 32.87it/s]


0.2776930052568399
checkpoint saved
Epoch:  3


937it [00:27, 33.56it/s]


0.25204056460140006
checkpoint saved
Epoch:  4


937it [00:27, 33.84it/s]


0.23403437963409573
checkpoint saved
Epoch:  5


937it [00:29, 32.14it/s]


0.22108195028729982
checkpoint saved
Epoch:  6


937it [00:28, 32.37it/s]


0.21342595421541144
checkpoint saved
Epoch:  7


937it [00:28, 33.22it/s]


0.2076955917157161
checkpoint saved
Epoch:  8


937it [00:28, 32.93it/s]


0.20299529471894848
checkpoint saved
Epoch:  9


937it [00:28, 33.42it/s]

0.19719728811573448
checkpoint saved



