<a href="https://colab.research.google.com/github/mmsamiei/acv_final_project/blob/main/small_norb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install timm
!pip install tqdm

Collecting timm
[?25l  Downloading https://files.pythonhosted.org/packages/22/c6/ba02d533cec7329323c7d7a317ab49f673846ecef202d4cc40988b6b7786/timm-0.3.4-py3-none-any.whl (244kB)
[K     |█▍                              | 10kB 25.4MB/s eta 0:00:01[K     |██▊                             | 20kB 29.8MB/s eta 0:00:01[K     |████                            | 30kB 20.8MB/s eta 0:00:01[K     |█████▍                          | 40kB 17.0MB/s eta 0:00:01[K     |██████▊                         | 51kB 15.5MB/s eta 0:00:01[K     |████████                        | 61kB 13.8MB/s eta 0:00:01[K     |█████████▍                      | 71kB 12.9MB/s eta 0:00:01[K     |██████████▊                     | 81kB 14.2MB/s eta 0:00:01[K     |████████████                    | 92kB 13.5MB/s eta 0:00:01[K     |█████████████▍                  | 102kB 13.2MB/s eta 0:00:01[K     |██████████████▊                 | 112kB 13.2MB/s eta 0:00:01[K     |████████████████                | 122kB 13.2MB/s e

In [2]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch
import timm
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

# Load Dataset

In [12]:
from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys
import struct
import math
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import torch.utils.data as data
from torchvision.datasets.utils import download_url, check_integrity

class smallNORB(data.Dataset):
    """`small NORB <https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/>`_ Dataset.
    Args:
        root (string): Root directory of dataset
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    urls_train = [
        [
            "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz",
            "smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz",
            "66054832f9accfe74a0f4c36a75bc0a2"
        ],
        [
            "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz",
            "smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz",
            "23c8b86101fbf0904a000b43d3ed2fd9"
        ],
        [
            "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz",
            "smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz",
            "51dee1210a742582ff607dfd94e332e3"
        ]
    ]
    urls_test = [
        [
            "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz",
            "smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz",
            "e4ad715691ed5a3a5f138751a4ceb071"
        ],
        [
            "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz",
            "smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz",
            "5aa791cd7e6016cf957ce9bdb93b8603"
        ],
        [
            "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz",
            "smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz",
            "a9454f3864d7fd4bb3ea7fc3eb84924e"
        ]
    ]

    train_data_file = ["smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat", "8138a0902307b32dfa0025a36dfa45ec"]
    train_labels_file = ["smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat", "fd5120d3f770ad57ebe620eb61a0b633"]
    train_info_file = ["smallnorb-5x46789x9x18x6x2x96x96-training-info.mat", "19faee774120001fc7e17980d6960451"]

    test_data_file =  ["smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat", "e9920b7f7b2869a8f1a12e945b2c166c"]
    test_labels_file = ["smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat", "fd5120d3f770ad57ebe620eb61a0b633"]
    test_info_file = ["smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat", "7c5b871cc69dcadec1bf6a18141f5edc"]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data = []
            self.train_labels = []
            self.train_info = []

            with open(os.path.join(self.root, self.train_data_file[0]), mode='rb') as f:
                self.train_data = self._parse_small_norb_data(f)
            with open(os.path.join(self.root, self.train_labels_file[0]), mode='rb') as f:
                self.train_labels = self._parse_small_norb_labels(f)
            with open(os.path.join(self.root, self.train_info_file[0]), mode='rb') as f:
                self.train_info = self._parse_small_norb_info(f)

            self.train_data = self.train_data[::2]

            indices = self.train_info[:,2]<10
            self.train_data = [self.train_data[i] for i,v in enumerate(indices) if v]
            self.train_labels = [self.train_labels[i] for i,v in enumerate(indices) if v]
            self.train_info = self.train_info[indices]
          
        else:
            with open(os.path.join(self.root, self.test_data_file[0]), mode='rb') as f:
                self.test_data = self._parse_small_norb_data(f)
            with open(os.path.join(self.root, self.test_labels_file[0]), mode='rb') as f:
                self.test_labels = self._parse_small_norb_labels(f)
            with open(os.path.join(self.root, self.test_info_file[0]), mode='rb') as f:
                self.test_info = self._parse_small_norb_info(f)


            self.test_data = self.test_data[::2]

            indices = self.test_info[:,2]>24
            self.test_data = [self.test_data[i] for i,v in enumerate(indices) if v]
            self.test_labels = [self.test_labels[i] for i,v in enumerate(indices) if v]
            self.test_info = self.test_info[indices]
          

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        dindex = math.floor(index)
        if self.train:
            img, target, info = self.train_data[index], self.train_labels[dindex], self.train_info[dindex]
        else:
            img, target, info = self.test_data[index], self.test_labels[dindex], self.test_info[dindex]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        img = img.convert('RGB')
        img = np.array(img.resize((224,224)))
        img = img.transpose((2, 0, 1))
        target = np.array([target])

        return {'image': torch.from_numpy(img),
                'target': torch.from_numpy(target)}


        #return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_integrity(self):
        root = self.root
        for fentry in ([
            self.train_data_file, 
            self.train_labels_file,
            self.train_info_file,
            self.test_data_file, 
            self.test_labels_file,
            self.test_info_file
        ]):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        import gzip
        import shutil

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root

        for url in (self.urls_train + self.urls_test):
            download_url(url[0], root, url[1], url[2])

            with gzip.open(os.path.join(root, url[1]), 'rb') as f_in:
                with open(os.path.join(root, url[1][:-3]), 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _parse_small_norb_header(self, file):
        magic = struct.unpack('<BBBB', file.read(4))
        ndims = struct.unpack('<i', file.read(4))[0]
        shape = []
        for _ in range(ndims):
            shape.append(struct.unpack('<i', file.read(4)))
        return {'magic_number' : magic, 'shape' : shape}

    def _parse_small_norb_data(self, file):
        self._parse_small_norb_header(file)
        data = []
        buf  = file.read(9216)
        while len(buf):
            data.append(Image.frombuffer('L', (96,96), buf, 'raw', 'L', 0, 1))
            buf = file.read(9216)
        return data

    def _parse_small_norb_labels(self, file):
        self._parse_small_norb_header(file)
        file.read(8)
        data = []
        buf  = file.read(4)
        while len(buf):
            data.append(struct.unpack('<i', buf)[0])
            buf = file.read(4)
        return data

    def _parse_small_norb_info(self, file):
        self._parse_small_norb_header(file)
        file.read(4)
        instance = []
        elevation = []
        azimuth = []
        lighting = []
        buf  = file.read(4)
        while len(buf):
            instance.append(struct.unpack('<i', buf)[0])
            buf = file.read(4)
            elevation.append(struct.unpack('<i', buf)[0])
            buf = file.read(4)
            azimuth.append(struct.unpack('<i', buf)[0])
            buf = file.read(4)
            lighting.append(struct.unpack('<i', buf)[0])
            buf = file.read(4)
        return np.array([instance, elevation, azimuth, lighting]).transpose()

In [13]:
toy_dataset_train = smallNORB('data/', train=True, download=True)
toy_dataset_test = smallNORB('data/', train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(toy_dataset_train, batch_size=32, shuffle=True)
test_dataloader = DataLoader(toy_dataset_test, batch_size=32, shuffle=True)

In [15]:
class CNNNet(torch.nn.Module):
  def __init__(self):
    super(CNNNet, self).__init__()
    self.cnn = timm.create_model('resnet50', pretrained=True, num_classes=0)
    self.fc = torch.nn.Linear(2048, 5)
    self.cnn = self.cnn.eval()
    for param in self.cnn.parameters():
      param.requires_grad = False
    nn.init.xavier_normal_(self.fc.weight)

  def forward(self, x):
    # x = [batch, channel, 224, 224]
    temp = self.cnn(x)
    temp = self.fc(temp)
    return temp

class TransformerNet(torch.nn.Module):
  def __init__(self):
    super(TransformerNet, self).__init__()
    self.transformer = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
    self.fc = torch.nn.Linear(768, 5)
    self.transformer = self.transformer.eval()
    for param in self.transformer.parameters():
      param.requires_grad = False
    nn.init.xavier_normal_(self.fc.weight)

  def forward(self, x):
    # x = [batch, channel, 224, 224]
    temp = self.transformer(x)
    temp = self.fc(temp)
    return temp


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        xavier(m.bias.data)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [16]:
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
x = torch.rand(4, 3, 224, 224)
m1 = CNNNet().to(dev)
m2 = TransformerNet().to(dev)

count_parameters(m1), count_parameters(m2)

(10245, 3845)

In [18]:
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(params =m1.fc.parameters() , lr=1e-4)

In [19]:
def evaluate(model): 
  correct = 0
  total = 0
  model = model.eval()

  with torch.no_grad():
    for i, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
      batch_x = batch['image'].to(dev).float()
      labels = batch['target'].to(dev)
      y_pred = model(batch_x)
      _, predicted = torch.max(y_pred.data, 1)
      total += labels.size(0)
      correct += (predicted == labels.squeeze(1)).sum().item()
  
  return correct / total



In [20]:
MAX_EPOCH = 5

for epoch in range(MAX_EPOCH):
  running_loss = 0.0

  for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    optimizer.zero_grad()
    batch_x = batch['image'].to(dev)
    batch_y = batch['target'].to(dev)
    y_pred = m1(batch_x.float())
    loss = criterion(y_pred, batch_y.squeeze(1))
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    if i % 200 == 199:    # print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
  accuracy = evaluate(m1)
  m1 = m1.train()
  print('epoch {}, accuracy: {}'.format(epoch, accuracy))

HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))

[1,   200] loss: 35.061



HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))


epoch 0, accuracy: 0.20266666666666666


HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))

[2,   200] loss: 1.096



HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))


epoch 1, accuracy: 0.7204444444444444


HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))

[3,   200] loss: 0.669



HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))


epoch 2, accuracy: 0.758962962962963


HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))

[4,   200] loss: 0.508



HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))


epoch 3, accuracy: 0.7761481481481481


HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))

[5,   200] loss: 0.416



HBox(children=(FloatProgress(value=0.0, max=211.0), HTML(value='')))


epoch 4, accuracy: 0.7844444444444445


In [None]:
optimizer = optim.Adam(params =m2.fc.parameters() , lr=1e-4)

MAX_EPOCH = 5

for epoch in range(MAX_EPOCH):
  running_loss = 0.0

  for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    optimizer.zero_grad()
    batch_x = batch['image'].to(dev)
    batch_y = batch['target'].to(dev)
    y_pred = m2(batch_x.float())
    loss = criterion(y_pred, batch_y.squeeze(1))
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    if i % 200 == 199:    # print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
  accuracy = evaluate(m2)
  m2 = m2.train()
  print('epoch {}, accuracy: {}'.format(epoch, accuracy))

HBox(children=(FloatProgress(value=0.0, max=1519.0), HTML(value='')))

[1,   200] loss: 1.608
[1,   400] loss: 1.563
[1,   600] loss: 1.538
[1,   800] loss: 1.516
[1,  1000] loss: 1.501
[1,  1200] loss: 1.485
[1,  1400] loss: 1.474



HBox(children=(FloatProgress(value=0.0, max=1519.0), HTML(value='')))


epoch 0, accuracy: 0.2719135802469136


HBox(children=(FloatProgress(value=0.0, max=1519.0), HTML(value='')))

[2,   200] loss: 1.451
[2,   400] loss: 1.447
[2,   600] loss: 1.432
[2,   800] loss: 1.426
[2,  1000] loss: 1.422
[2,  1200] loss: 1.415
[2,  1400] loss: 1.407



HBox(children=(FloatProgress(value=0.0, max=1519.0), HTML(value='')))


epoch 1, accuracy: 0.2666460905349794


HBox(children=(FloatProgress(value=0.0, max=1519.0), HTML(value='')))

[3,   200] loss: 1.403
[3,   400] loss: 1.391
[3,   600] loss: 1.380
[3,   800] loss: 1.382
[3,  1000] loss: 1.381
[3,  1200] loss: 1.372
[3,  1400] loss: 1.367
