In [1]:
from google.colab import drive

# Attach Google Drive at /content/drive so the project checkout stored there is visible.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Work from the ChipNet checkout so the relative paths used later in this notebook
# ("checkpoints/...", "logs/...") resolve inside the project directory.
%cd /content/drive/MyDrive/Transmute.AI/ChipNet/ChipNet-master

/content/drive/MyDrive/Transmute.AI/ChipNet/ChipNet-master


In [3]:
import sys

# Directory containing the project's local modules (datasets, utils, models)
# that the import cell below relies on.
PROJECT_ROOT = '/content/drive/MyDrive/Transmute.AI/ChipNet/ChipNet-master'

# Only append once: re-running this cell must not keep growing sys.path
# with duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [4]:
import argparse
import os

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm as tqdm_notebook
from datasets import DataManager
from utils import *
from models import get_model

In [5]:
# Fix the random seed before any data loading or model construction.
# seed_everything is not defined in this notebook; presumably it is provided by the
# `from utils import *` star-import and seeds the Python/NumPy/torch RNGs — TODO confirm.
seed_everything(43)

In [7]:
# ---- Experiment identity --------------------------------------------------
dataset = 'c100'        # first argument to DataManager
model = 'r164'          # architecture key handed to get_model (rebound to the network later)
name = "r164"
host_name = None        # when set, the checkpoint path becomes checkpoints/{host_name}_pruned.pth

# ---- Pruning budget -------------------------------------------------------
budget_type = 'channel_ratio'
Vc = 0.5                # wrapped in a tensor later; its value is passed to prepare_for_finetuning

# ---- Data loading ---------------------------------------------------------
batch_size = 128
valid_size = 0.1
workers = 0

# ---- Optimisation ---------------------------------------------------------
epochs = 10
lr = 0.05               # initial SGD learning rate, also passed to adjust_learning_rate
scheduler_type = 1      # forwarded to adjust_learning_rate
decay = 0.001           # used as the SGD weight_decay

# ---- Runtime --------------------------------------------------------------
test_only = False       # True skips the training loop
cuda_id = 0             # index used to build the "cuda:<id>" device string

In [14]:
# One-element float tensor whose value is formatted into the pruned-checkpoint file name.
# NOTE(review): this lower-case `vc` (0.25) is a different variable from the config
# value `Vc` (0.5) — confirm that both budgets are intended.
vc = torch.FloatTensor([0.25])

In [9]:
# Wrap the scalar budget in a one-element float tensor. The guard keeps this cell
# idempotent: on a re-run Vc is already a tensor and must not be wrapped again.
if not torch.is_tensor(Vc):
    Vc = torch.FloatTensor([Vc])

# Choose the pruned checkpoint to fine-tune from.
# NOTE(review): the default file name hard-codes "r164_c10" and "channel_ratio" while the
# config cell sets dataset = 'c100' — confirm this is the checkpoint that is meant to be loaded.
if host_name is None:
    model_path = f"checkpoints/r164_c10_{str(np.round(vc.item(),decimals=6))}_channel_ratio_pruned.pth"
else:
    model_path = f"checkpoints/{host_name}_pruned.pth"

In [10]:
# Build the train / validation / test loaders through the project's DataManager.
data_object = DataManager(dataset, batch_size, workers, valid_size)
trainloader, valloader, testloader = data_object.prepare_data()

# Index the loaders by phase name; train() and test() look them up through this dict.
dataloaders = dict(train=trainloader, val=valloader, test=testloader)

... Preparing data ...
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
using fixed split
45000 5000


In [11]:
# Instantiate the network in 'prune' mode, sized from the prepared dataset.
# NOTE(review): this rebinds `model` from the architecture string to the network object,
# so re-running this cell without re-running the config cell passes a module to get_model.
model = get_model(model, 'prune', data_object.num_classes, data_object.insize)

# Deserialize onto the CPU so a checkpoint saved from a GPU still loads on a host
# without that device; load_state_dict copies the values into the model's own tensors.
checkpoint_state = torch.load(model_path, map_location="cpu")['state_dict']
if host_name is not None:
    # Host checkpoint: pass it through get_mask_dict together with the fresh model's state.
    host_state = checkpoint_state
    model.load_state_dict(get_mask_dict(model.state_dict(), host_state), strict = False)
else:
    state = checkpoint_state
    model.load_state_dict(state, strict=False)
# Single cross-entropy module shared by every criterion() call.
CE = nn.CrossEntropyLoss()


def criterion(model, y_pred, y_true):
    """Return the cross-entropy loss of ``y_pred`` against ``y_true``.

    ``model`` is accepted so the call shape is ``loss_fn(model, outputs, targets)``
    as used by train() and test(); it is not used in the computation.
    """
    return CE(y_pred, y_true)

RuntimeError: ignored

In [None]:
# Use the configured GPU when CUDA is present; otherwise fall back to the CPU
# instead of failing when the model is moved.
device = torch.device(f"cuda:{cuda_id}" if torch.cuda.is_available() else "cpu")
model.to(device)

# Tensor.to() returns a new tensor rather than moving in place, so the result
# must be bound back to the name (the original call discarded it).
Vc = Vc.to(device)

# Build the optimizer only after the model sits on its final device.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=decay)

In [None]:
def train(model, loss_fn, optimizer):
    """Run one training epoch over ``dataloaders['train']``.

    Reads the notebook-level ``dataloaders`` dict and ``device``.

    Args:
        model: network to optimise; put into training mode here.
        loss_fn: callable invoked as ``loss_fn(model, outputs, targets)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        The mean of the per-batch loss values over the epoch.
    """
    model.train()
    train_loader = dataloaders['train']
    progress = tqdm_notebook(train_loader, total=len(train_loader))
    loss_sum = 0.
    batches_seen = 0
    for inputs, targets in progress:
        batches_seen += 1
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)
        batch_loss = loss_fn(model, model(inputs), targets)
        loss_sum += batch_loss.item()
        # Show the running mean loss on the progress bar.
        progress.set_postfix(loss=loss_sum / batches_seen)
        # Clear the previous batch's gradients before back-propagating this one.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    return loss_sum / batches_seen

In [None]:
def test(model, loss_fn, optimizer, phase):
    """Evaluate ``model`` on ``dataloaders[phase]`` without tracking gradients.

    Reads the notebook-level ``dataloaders`` dict and ``device``.

    Args:
        model: network to evaluate; put into eval mode here.
        loss_fn: callable invoked as ``loss_fn(model, outputs, targets)``.
        optimizer: not used by this function; kept in the signature for its callers.
        phase: key into ``dataloaders`` (the notebook uses "val" and "test").

    Returns:
        Tuple ``(accuracy, mean_loss)``: correct predictions over samples seen,
        and the mean of the per-batch loss values.
    """
    model.eval()
    loader = dataloaders[phase]
    progress = tqdm_notebook(loader, total=len(loader))
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in progress:
            num_batches += 1
            inputs = inputs.to(device=device)
            targets = targets.to(device=device)
            logits = model(inputs)
            loss_sum += loss_fn(model, logits, targets).item()
            # Predicted class = index of the largest score along dimension 1.
            predictions = torch.max(logits, 1).indices
            num_correct += (predictions == targets).sum().item()
            num_seen += predictions.shape[0]
            progress.set_postfix(loss=loss_sum / num_batches, acc=num_correct / num_seen)
    return num_correct / num_seen, loss_sum / num_batches

In [None]:
# Per the original author's note: sets beta and gamma and unfreezes the network except zetas.
# NOTE(review): the budget passed here is Vc (0.5 in the config cell), whereas the checkpoint
# file name was built from the lower-case vc (0.25) — confirm the two are meant to differ.
model.prepare_for_finetuning(device, Vc.item(), budget_type=budget_type)

In [None]:
# Fine-tuning loop: train for `epochs` epochs, keep the checkpoint with the best
# validation accuracy, and rewrite the per-epoch log CSV after every epoch.
best_accuracy = 0
num_epochs = epochs
train_losses = []
valid_losses = []
valid_accuracy = []

if not test_only:
    # Create the output directories up front so the first save cannot fail
    # after a full epoch of training.
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("logs", exist_ok=True)

    for epoch in range(num_epochs):
        adjust_learning_rate(optimizer, epoch, scheduler_type, lr, epochs)
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        train_loss = train(model, criterion, optimizer)
        accuracy, valid_loss = test(model, criterion, optimizer, "val")
        # NOTE(review): 20. is a magic constant passed to get_remaining; its meaning
        # is not visible in this notebook — confirm against the model code.
        remaining = model.get_remaining(20.,budget_type).item()

        # Checkpoint only when validation accuracy improves on the best so far.
        if accuracy > best_accuracy:
            print("**Saving model**")
            best_accuracy = accuracy
            torch.save({
                "epoch": epoch + 1,
                "state_dict" : model.state_dict(),
                "acc" : best_accuracy,
                "rem" : remaining,
            }, "checkpoints/r164_c10_finetuned.pth")

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_accuracy.append(accuracy)

        # Rewrite the whole log each epoch so an interrupted run still leaves a usable file.
        df = pd.DataFrame({
            'train_losses': train_losses,
            'valid_losses': valid_losses,
            'valid_accuracy': valid_accuracy,
        })
        df.to_csv("logs/r164_c10_finetuned.csv")

Starting epoch 1 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.59]
100%|██████████| 40/40 [00:09<00:00,  4.33it/s, acc=0.449, loss=2.07]


**Saving model**
Starting epoch 2 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.56]
100%|██████████| 40/40 [00:09<00:00,  4.29it/s, acc=0.496, loss=1.8]


**Saving model**
Starting epoch 3 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.58]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.442, loss=2.13]


Starting epoch 4 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.6]
100%|██████████| 40/40 [00:09<00:00,  4.32it/s, acc=0.483, loss=1.92]


Starting epoch 5 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.62]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.448, loss=2.09]


Starting epoch 6 / 50


100%|██████████| 352/352 [06:00<00:00,  1.02s/it, loss=1.63]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.475, loss=1.97]


Starting epoch 7 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.64]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.419, loss=2.24]


Starting epoch 8 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.65]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.462, loss=2.01]


Starting epoch 9 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.64]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.454, loss=2.09]


Starting epoch 10 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.63]
100%|██████████| 40/40 [00:09<00:00,  4.33it/s, acc=0.47, loss=1.98]


Starting epoch 11 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.63]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.415, loss=2.25]


Starting epoch 12 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.61]
100%|██████████| 40/40 [00:09<00:00,  4.32it/s, acc=0.462, loss=2]


Starting epoch 13 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.6]
100%|██████████| 40/40 [00:09<00:00,  4.32it/s, acc=0.448, loss=2.04]


Starting epoch 14 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.6]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.44, loss=2.12]


Starting epoch 15 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.58]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.456, loss=2.05]


Starting epoch 16 / 50


100%|██████████| 352/352 [06:00<00:00,  1.02s/it, loss=1.57]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.434, loss=2.14]


Starting epoch 17 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.55]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.41, loss=2.28]


Starting epoch 18 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.56]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.477, loss=1.95]


Starting epoch 19 / 50


100%|██████████| 352/352 [06:00<00:00,  1.02s/it, loss=1.53]
100%|██████████| 40/40 [00:09<00:00,  4.33it/s, acc=0.529, loss=1.71]


**Saving model**
Starting epoch 20 / 50


100%|██████████| 352/352 [06:01<00:00,  1.03s/it, loss=1.52]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.456, loss=2.08]


Starting epoch 21 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.51]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.427, loss=2.22]


Starting epoch 22 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.5]
100%|██████████| 40/40 [00:09<00:00,  4.29it/s, acc=0.503, loss=1.83]


Starting epoch 23 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.51]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.452, loss=2.12]


Starting epoch 24 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.48]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.463, loss=2.1]


Starting epoch 25 / 50


100%|██████████| 352/352 [06:01<00:00,  1.03s/it, loss=1.48]
100%|██████████| 40/40 [00:09<00:00,  4.30it/s, acc=0.497, loss=1.88]


Starting epoch 26 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.48]
100%|██████████| 40/40 [00:09<00:00,  4.29it/s, acc=0.508, loss=1.82]


Starting epoch 27 / 50


100%|██████████| 352/352 [06:02<00:00,  1.03s/it, loss=1.46]
100%|██████████| 40/40 [00:09<00:00,  4.28it/s, acc=0.403, loss=2.35]


Starting epoch 28 / 50


100%|██████████| 352/352 [06:00<00:00,  1.02s/it, loss=1.47]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.521, loss=1.79]


Starting epoch 29 / 50


100%|██████████| 352/352 [05:58<00:00,  1.02s/it, loss=1.47]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.481, loss=1.91]


Starting epoch 30 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.45]
100%|██████████| 40/40 [00:09<00:00,  4.33it/s, acc=0.423, loss=2.28]


Starting epoch 31 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.15]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.607, loss=1.45]


**Saving model**
Starting epoch 32 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.11]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.605, loss=1.41]


Starting epoch 33 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.11]
100%|██████████| 40/40 [00:09<00:00,  4.32it/s, acc=0.582, loss=1.51]


Starting epoch 34 / 50


100%|██████████| 352/352 [05:58<00:00,  1.02s/it, loss=1.12]
100%|██████████| 40/40 [00:09<00:00,  4.33it/s, acc=0.578, loss=1.5]


Starting epoch 35 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.12]
100%|██████████| 40/40 [00:09<00:00,  4.35it/s, acc=0.576, loss=1.5]


Starting epoch 36 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.13]
100%|██████████| 40/40 [00:09<00:00,  4.34it/s, acc=0.573, loss=1.5]


Starting epoch 37 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.12]
100%|██████████| 40/40 [00:09<00:00,  4.31it/s, acc=0.586, loss=1.47]


Starting epoch 38 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.12]
100%|██████████| 40/40 [00:09<00:00,  4.34it/s, acc=0.577, loss=1.48]


Starting epoch 39 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.11]
100%|██████████| 40/40 [00:09<00:00,  4.35it/s, acc=0.534, loss=1.73]


Starting epoch 40 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.12]
100%|██████████| 40/40 [00:09<00:00,  4.37it/s, acc=0.604, loss=1.42]


Starting epoch 41 / 50


100%|██████████| 352/352 [05:59<00:00,  1.02s/it, loss=1.11]
100%|██████████| 40/40 [00:09<00:00,  4.37it/s, acc=0.601, loss=1.4]


Starting epoch 42 / 50


 87%|████████▋ | 305/352 [05:12<00:48,  1.03s/it, loss=1.09]

In [None]:
# Reload the checkpoint saved at the best validation accuracy and score it on the test split.
# map_location="cpu" lets the file load even if the device it was saved from is unavailable;
# load_state_dict then copies the values into the model's existing tensors.
state = torch.load("checkpoints/r164_c10_finetuned.pth", map_location="cpu")
model.load_state_dict(state['state_dict'], strict=True)
acc, test_loss = test(model, criterion, optimizer, "test")
print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}")