<a href="https://colab.research.google.com/github/meghbhalerao/nrlpq/blob/main/nrlpq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch.nn as nn
import torch
import torchvision
import numpy as np
from torch.utils.data import Subset
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional
import os
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
from torchvision import models
import torch.nn.functional as F

Mounted at /content/gdrive


In [3]:
num_classes = 10

In [4]:
nesting_start=3

class BlurPoolConv2d(torch.nn.Module):
    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        self.register_buffer('blur_filter', filt)

    def forward(self, x):
        blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),
                           groups=self.conv.in_channels, bias=None)
        return self.conv.forward(blurred)

class Model():
    def __init__(self, gpu, nesting, single_head, fixed_feature, use_blurpool):
        super().__init__()
        self.gpu = gpu
        self.nesting = nesting
        self.sh = single_head
        self.ff = fixed_feature
        self.use_blurpool = use_blurpool


    def load_model(self, model, model_weights_disk):
        if os.path.isfile(model_weights_disk):
            print("=> loading checkpoint '{}'".format(model_weights_disk))
            if self.gpu is None:
                checkpoint = torch.load(model_weights_disk)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(self.gpu)
                checkpoint = torch.load(model_weights_disk, map_location=loc)
            model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' "
                  .format(model_weights_disk))
        else:
            print("=> no model found at '{}'".format(model_weights_disk))

        return model

    def initModel(self):
        print("Model init: nesting=%d, sh=%d, ff=%d" %(self.nesting, self.sh, self.ff))
        model = models.resnet50(pretrained=True)
        nesting_list = [2**i for i in range(nesting_start, 12)] if self.nesting else None

        # Nesting/Fixed Feature Modification code block
        if self.nesting:
            ff= "Single Head" if self.sh else "Multi Head"
            print("Using Nesting of type - {}".format(ff))
            print("Nesting Starts from {}".format(2**nesting_start))
            if self.sh:
                model.fc =  SingleHeadNestedLinear(nesting_list, num_classes=num_classes)
            else:
                model.fc =  MultiHeadNestedLinear(nesting_list, num_classes=num_classes)
        elif self.ff != 2048:
            print(f"Using Fixed Features = {self.ff}")
            model.fc =  FixedFeatureLayer(self.ff, num_classes)

        def apply_blurpool(mod: torch.nn.Module):
            for (name, child) in mod.named_children():
                if isinstance(child, torch.nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16):
                    setattr(mod, name, BlurPoolConv2d(child))
                else: apply_blurpool(child)
        if self.use_blurpool: apply_blurpool(model)

        model = model.to(memory_format=torch.channels_last)
        model = model.to(self.gpu)

        return model


import torch
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional

class SingleHeadNestedLinear(nn.Linear):
	def __init__(self, nesting_list: List, num_classes=10, **kwargs):
		super(SingleHeadNestedLinear, self).__init__(nesting_list[-1], num_classes, **kwargs)
		self.nesting_list=nesting_list
		self.num_classes=num_classes # Number of classes for classification

	def forward(self, x):
		nesting_logits = ()
		for i, num_feat in enumerate(self.nesting_list):
			if not (self.bias is None):
				logit = torch.matmul(x[:, :num_feat], (self.weight[:, :num_feat]).t()) + self.bias
			else:
				logit = torch.matmul(x[:, :num_feat], (self.weight[:, :num_feat]).t())
			nesting_logits+= (logit,)
		return nesting_logits

class MultiHeadNestedLinear(nn.Module):
	def __init__(self, nesting_list: List, num_classes=num_classes, **kwargs):
		super(MultiHeadNestedLinear, self).__init__()
		self.nesting_list=nesting_list
		self.num_classes=num_classes # Number of classes for classification
		for i, num_feat in enumerate(self.nesting_list):
			setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))		

	def forward(self, x):
		nesting_logits = ()
		for i, num_feat in enumerate(self.nesting_list):
			nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)
		return nesting_logits

		
class FixedFeatureLayer(nn.Linear):
    # This layer just takes the first "K" Features for the classification. 
    # Creating a separate layer and customized fwd pass helps to not change the base codes at all.
    def __init__(self, in_features, out_features, **kwargs):
        super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        if not (self.bias is None):
            out = torch.matmul(x[:, :self.in_features], self.weight.t()) + self.bias
        else:
            out = torch.matmul(x[:, :self.in_features], self.weight.t())
        return out

class NestedCELoss(nn.Module):
	def __init__(self, **kwargs):
		super(NestedCELoss, self).__init__()
		self.criterion = nn.CrossEntropyLoss(**kwargs)
	def forward(self, output, target):
		loss=0
		for o in output:
			loss+= self.criterion(o, target)
		return loss

In [5]:
mdl_wts = os.path.join("/content/gdrive/MyDrive/nrlpq/Imagenet1k_R50_sh0_mh0_ns3_ff2048/final_weights.pt")
model_wts_path = os.path.join(mdl_wts)
nesting = 1
single_head = 1
fixed_feature = 2048
model = Model(0, nesting, single_head, fixed_feature, use_blurpool=1)


model_init = model.initModel()
model = model.load_model(model_init, model_wts_path)
print("Loaded pretrained model: " + str(model_wts_path))


batch_size = 128
print("batch size is", batch_size)
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset_all = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
n_alltrain = len(trainset_all)
print("len of all train data is", n_alltrain)

n_train = int(n_alltrain * 0.8)
n_val = n_alltrain - n_train
print("len of train val split is ", n_train, n_val, "respectively")

val_idxs = np.random.choice(n_alltrain, size = n_val ,replace=False)

trainset = Subset(trainset_all, list(set(range(len(trainset_all))) -  set(val_idxs)))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

valset = Subset(trainset_all, val_idxs)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


print("length of train, val and test DataLoader is ", len(train_loader), len(val_loader), len(test_loader))

Model init: nesting=1, sh=1, ff=2048


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Using Nesting of type - Single Head
Nesting Starts from 8
=> no model found at '/content/gdrive/MyDrive/nrlpq/Imagenet1k_R50_sh0_mh0_ns3_ff2048/final_weights.pt'
Loaded pretrained model: /content/gdrive/MyDrive/nrlpq/Imagenet1k_R50_sh0_mh0_ns3_ff2048/final_weights.pt
batch size is 128
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
len of all train data is 50000
len of train val split is  40000 10000 respectively
Files already downloaded and verified
length of train, val and test DataLoader is  313 79 79


In [7]:

def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device):
    train_losses = []
    valid_losses = []

    for epoch in range(0, epochs):
        # training
        model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device)
        train_losses.append(train_loss)

        # validation
        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)
        print(f'EPOCH:{epoch}')
  
    return model, train_losses, valid_losses

def train(train_loader, model, criterion, optimizer, device):
    model.train()
    running_loss = 0
    for X, y in train_loader:
        optimizer.zero_grad()
        X = X.to(device)
        y = y.to(device)
        pred = model(X) 
      #  print(pred.shape)
        print(y.shape)
        
        loss = criterion(pred, y) 
        running_loss += loss.item() * X.size(0)
        loss.backward()
        optimizer.step()
        print(loss.cpu().data.item())
        break
    epoch_loss = running_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss

def validate(valid_loader, model, criterion, device):
    model.eval()
    running_loss = 0
    correct_count, all_count = 0, 0
    
    for X, y in valid_loader:
        X = X.to(device)
        y = y.to(device)
        pred = model(X) 
        
        loss = criterion(pred, y) 
        running_loss += loss.item() * X.size(0)

    epoch_loss = running_loss / len(valid_loader.dataset)
    return model, epoch_loss

def get_accuracy(model, data_loader):
  correct_count, all_count = 0, 0
  pred_prob = 0
  i=0
  first_batch = None
  init_pred_label = None
  for images,labels in data_loader:
    images,labels = images.to(device), labels.to(device)
    with torch.no_grad():
        logps = model(images)
    ps = F.softmax(logps, dim=1)
    if device!='cpu':
      probab = list(ps.cpu().numpy()[0])
      pred_label = ps.argmax(axis=1)
      if i==0:
        pred_prob = ps.max(axis=1)
        init_pred_label = pred_label
        first_batch = (images, labels)
        i+=1
      true_label = labels
    else:
      probab = list(ps.numpy()[0])
      pred_label = ps.argmax(axis=1)
      if i==0:
        pred_prob = ps.max(axis=1)
        init_pred_label = pred_label
        first_batch = (images, labels)
        i+=1
      true_label = labels
    correct_count += torch.eq(true_label, pred_label).sum().item()
    all_count += len(true_label)
  return correct_count/all_count, [i.item() for i in list(pred_prob[0])[:10]], init_pred_label, first_batch

def get_accuracy(model,data_loader):
  correct_count, all_count = 0, 0
  for images,labels in data_loader:
    images,labels = images.to(device), labels.to(device)
    images = images.to(device)
    labels = labels.to(device)
    pred = model(images) 
    

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = NestedCELoss()
epochs = 2
device = torch.device('cuda')
training_loop(model, criterion, optimizer, train_loader, val_loader, epochs, device)

torch.Size([128])
20.396411895751953
EPOCH:0
torch.Size([128])
19.705293655395508
EPOCH:1


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): Bottleneck(
       (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (downsample): Sequential(
         (0): Conv2d(64, 256, kernel_size=(1,

In [19]:
res = get_accuracy(model, val_loader)
res[0]

torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128, 1000])
torch.Size([128,

TypeError: ignored

In [20]:
model.summary()

AttributeError: ignored

In [41]:
 for images,labels in test_loader:
   print(images.shape)
   print(labels)
   print(F.one_hot(labels))
   break

torch.Size([128, 3, 32, 32])
tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3, 6, 2, 1, 2, 3, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
        8, 3, 1, 2, 8, 0, 8, 3])
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 1, 0],
        [0, 0, 0,  ..., 0, 1, 0],
        ...,
        [1, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 1, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
