In [1]:
# Mount Google Drive inside the Colab VM so the project files stored there are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Work from the project checkout so the relative paths used later in this notebook (checkpoints/, logs/) resolve inside it.
%cd /content/drive/MyDrive/Transmute.AI/ChipNet/ChipNet-master

/content/drive/MyDrive/Transmute.AI/ChipNet/ChipNet-master


In [3]:
import sys

# Put the project checkout on the import path so its local modules can be imported.
# The membership guard keeps the cell idempotent: re-running it no longer appends
# a duplicate entry to sys.path each time.
_project_root = '/content/drive/MyDrive/Transmute.AI/ChipNet/ChipNet-master'
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [4]:
import argparse
import os

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm as tqdm_notebook

from utils import *
from models import get_model
from datasets import DataManager

In [5]:
# Fix the seed before any data loading or model construction.
# seed_everything arrives via `from utils import *`; presumably it seeds the
# Python / NumPy / torch RNGs — confirm in utils.
seed_everything(43)

In [7]:
# --- Experiment selection ---------------------------------------------------
dataset = 'c100'                # dataset key handed to DataManager
model = 'r164'                  # architecture key handed to get_model
budget_type = 'channel_ratio'   # budget measure passed to get_remaining / prune
Vc = 0.25                       # target budget the budget loss pulls towards
test_only = False               # the training loop only runs when this is False

# --- Data loading / optimisation ---------------------------------------------
batch_size = 32
workers = 0
valid_size = 0.1
epochs = 5
lr = 0.001
decay = 0.001                   # weight decay for parameters outside the no-decay group

# --- Loss weights and schedule increments ------------------------------------
w1 = 30.0                       # weight of the budget loss
w2 = 10.0                       # weight of the crispness loss
b_inc = 5.0                     # beta grows by 0.1 / b_inc after each epoch
g_inc = 2.0                     # gamma is multiplied by 2 ** (1 / g_inc) after each epoch

# --- Hardware ------------------------------------------------------------------
cuda_id = 0                     # index of the CUDA device to run on

In [8]:
# Wrap the target budget in a one-element float32 tensor for the budget loss.
# torch.as_tensor + reshape(1) accepts both the plain float from the config cell
# and an already-converted tensor, so re-running this cell is harmless (the legacy
# torch.FloatTensor([...]) constructor would be fed a tensor on a second run).
Vc = torch.as_tensor(Vc, dtype=torch.float32).reshape(1)

In [9]:
# Build the data pipeline and index the three loaders by phase name, which is
# how train() / test() look them up.
data_object = DataManager(dataset, batch_size, workers, valid_size)
trainloader, valloader, testloader = data_object.prepare_data()
dataloaders = dict(
    zip(('train', 'val', 'test'), (trainloader, valloader, testloader))
)

... Preparing data ...
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
using fixed split
45000 5000


In [10]:
# Build the prunable network and warm-start it from the pretrained checkpoint.
# NOTE(review): `model` is rebound here from the architecture name (a str) to the
# network object, so this cell cannot be re-run without re-running the config cell first.
model = get_model(model, 'prune', data_object.num_classes, data_object.insize)
# NOTE(review): the checkpoint file name says "c10" while `dataset` is 'c100' —
# confirm this is the intended pretrained weights file.
# map_location="cpu" restores the checkpoint tensors onto the CPU instead of the
# device they were saved from, so the load does not depend on that device being
# present and does not hold a second copy of the weights on the GPU; the model
# itself is moved to `device` in a later cell.
# torch.load unpickles the file: only load checkpoints from a trusted source.
state = torch.load("checkpoints/r164_c10_pretrained.pth", map_location="cpu")
# strict=False tolerates keys the checkpoint lacks; the returned summary (shown
# as the cell output) lists them — here the per-layer zeta/beta/gamma entries.
model.load_state_dict(state['state_dict'], strict=False)

_IncompatibleKeys(missing_keys=['layer1.0.bn1.zeta', 'layer1.0.bn1.beta', 'layer1.0.bn1.gamma', 'layer1.0.bn2.zeta', 'layer1.0.bn2.beta', 'layer1.0.bn2.gamma', 'layer1.0.bn3.zeta', 'layer1.0.bn3.beta', 'layer1.0.bn3.gamma', 'layer1.1.bn1.zeta', 'layer1.1.bn1.beta', 'layer1.1.bn1.gamma', 'layer1.1.bn2.zeta', 'layer1.1.bn2.beta', 'layer1.1.bn2.gamma', 'layer1.1.bn3.zeta', 'layer1.1.bn3.beta', 'layer1.1.bn3.gamma', 'layer1.2.bn1.zeta', 'layer1.2.bn1.beta', 'layer1.2.bn1.gamma', 'layer1.2.bn2.zeta', 'layer1.2.bn2.beta', 'layer1.2.bn2.gamma', 'layer1.2.bn3.zeta', 'layer1.2.bn3.beta', 'layer1.2.bn3.gamma', 'layer1.3.bn1.zeta', 'layer1.3.bn1.beta', 'layer1.3.bn1.gamma', 'layer1.3.bn2.zeta', 'layer1.3.bn2.beta', 'layer1.3.bn2.gamma', 'layer1.3.bn3.zeta', 'layer1.3.bn3.beta', 'layer1.3.bn3.gamma', 'layer1.4.bn1.zeta', 'layer1.4.bn1.beta', 'layer1.4.bn1.gamma', 'layer1.4.bn2.zeta', 'layer1.4.bn2.beta', 'layer1.4.bn2.gamma', 'layer1.4.bn3.zeta', 'layer1.4.bn3.beta', 'layer1.4.bn3.gamma', 'layer1.

In [11]:
# Make sure the output directories exist. exist_ok=True replaces the manual
# "exists() == False, then mkdir" check, so the cell is safe to re-run and has
# no gap between the check and the creation.
os.makedirs("logs", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

In [12]:
# Loss-term weights (copied from the config values) and the starting steepness.
weightage1, weightage2 = w1, w2  # budget-loss weight, crispness-loss weight
steepness = 10.0  # steepness of gate_approximator; train() raises it as training proceeds


In [19]:
CE = nn.CrossEntropyLoss()


def criterion(model, y_pred, y_true):
    """Combined training objective: cross-entropy + budget term + crispness term.

    Args:
        model: network exposing ``get_remaining`` and ``get_crispnessLoss``.
        y_pred: model outputs passed to the cross-entropy loss.
        y_true: target labels passed to the cross-entropy loss.

    Returns:
        ``budget * weightage1 + crispness * weightage2 + cross_entropy``.

    Reads the module-level ``steepness``, ``budget_type``, ``Vc``, ``device``,
    ``weightage1`` and ``weightage2`` at call time.
    """
    classification_term = CE(y_pred, y_true)
    # Squared distance between the model's current remaining budget and the target Vc.
    remaining_now = model.get_remaining(steepness, budget_type).to(device)
    target_budget = Vc.to(device)
    budget_term = ((remaining_now - target_budget) ** 2).to(device)
    crispness_term = model.get_crispnessLoss(device)
    return budget_term * weightage1 + crispness_term * weightage2 + classification_term

In [20]:
# Two AdamW parameter groups: parameters whose name contains an entry of
# `no_decay` get weight_decay 0.0, every other parameter gets `decay`.
param_optimizer = list(model.named_parameters())
no_decay = ["zeta"]

decayed_params, undecayed_params = [], []
for param_name, param in param_optimizer:
    is_exempt = any(tag in param_name for tag in no_decay)
    (undecayed_params if is_exempt else decayed_params).append(param)

optimizer_parameters = [
    {'params': decayed_params, 'weight_decay': decay, 'lr': lr},
    {'params': undecayed_params, 'weight_decay': 0.0, 'lr': lr},
]
optimizer = optim.AdamW(optimizer_parameters)

In [21]:
# Pick the compute device. Fall back to the CPU when CUDA is not available so
# the notebook still runs (slowly) on a runtime without a GPU instead of
# failing on the first .to(device).
device = torch.device(f"cuda:{cuda_id}" if torch.cuda.is_available() else "cpu")
# Module.to moves the model's parameters in place, so no re-assignment is needed.
model.to(device)
# Tensor.to is NOT in place: this expression only displays a device copy of Vc.
# The module-level Vc stays where it was; criterion() moves it on every call.
Vc.to(device)

tensor([0.2500], device='cuda:0')

In [22]:
def train(model, loss_fn, optimizer, epoch):
    """Run one optimisation pass over ``dataloaders['train']``.

    Args:
        model: network to train; called as ``model(inputs)``.
        loss_fn: callable invoked as ``loss_fn(model, outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        epoch: accepted for signature symmetry; not used in the body.

    Returns:
        Mean of the per-batch loss values over the pass.

    Side effect: after every batch the module-level ``steepness`` is raised by
    ``5 / number_of_batches`` and capped at 60.
    """
    global steepness
    model.train()
    progress = tqdm_notebook(dataloaders['train'], total=len(dataloaders['train']))
    loss_sum = 0
    batches_seen = 0
    for batches_seen, (inputs, targets) in enumerate(progress, start=1):
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)
        batch_loss = loss_fn(model, model(inputs), targets)
        optimizer.zero_grad()
        batch_loss.backward()
        loss_sum += batch_loss.item()
        progress.set_postfix(loss=loss_sum / batches_seen)
        optimizer.step()
        steepness = min(60, steepness + 5. / len(progress))
    return loss_sum / batches_seen

In [23]:
def test(model, loss_fn, optimizer, phase, epoch):
    """Evaluate ``model`` on ``dataloaders[phase]`` with gradients disabled.

    Args:
        model: network to evaluate; called as ``model(inputs)``.
        loss_fn: callable invoked as ``loss_fn(model, outputs, targets)``; its
            value is only shown in the progress bar.
        optimizer: accepted but not used in the body.
        phase: key into the module-level ``dataloaders`` dict.
        epoch: accepted but not used in the body.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    progress = tqdm_notebook(dataloaders[phase], total=len(dataloaders[phase]))
    loss_sum = 0
    hits = 0
    samples_seen = 0
    batches_seen = 0
    with torch.no_grad():
        for batches_seen, (inputs, targets) in enumerate(progress, start=1):
            inputs = inputs.to(device=device)
            targets = targets.to(device=device)
            outputs = model(inputs)
            batch_loss = loss_fn(model, outputs, targets)
            # Index of the largest output along dim 1 is the predicted class.
            predicted = torch.max(outputs.data, 1)[1]
            hits += (predicted == targets).sum().item()
            loss_sum += batch_loss.item()
            samples_seen += predicted.shape[0]
            progress.set_postfix(loss=(loss_sum / batches_seen), acc=(hits / samples_seen))
    return hits / samples_seen

In [24]:
# Starting values for the best-accuracy tracker and the beta/gamma schedule,
# pushed into the model before the first epoch.
best_acc = 0
beta = 1.
gamma = 2.
model.set_beta_gamma(beta, gamma)


In [25]:
# Training driver. For every epoch: train, validate before pruning, prune and
# validate again, advance the beta/gamma schedule, checkpoint when the
# post-pruning accuracy improves, and rewrite the CSV log.

# Per-epoch history; each list gains one entry per epoch and becomes a CSV column.
remaining_before_pruning = []
remaining_after_pruning = []
valid_accuracy = []
pruning_accuracy = []
pruning_threshold = []
# exact_zeros = []
# exact_ones = []
problems = []
# Base name shared by the checkpoint and the CSV log.
# NOTE(review): 'r164', 'c10' and 'channel_ratio' are hard-coded here instead of
# being derived from the config values (`dataset` is 'c100' in the config cell) —
# confirm the file name is what is intended.
name = f'r164_c10_{str(np.round(Vc.item(),decimals=6))}_channel_ratio_pruned'
if test_only == False:
    for epoch in range(epochs):
        print(f'Starting epoch {epoch + 1} / {epochs}')
        # presumably undoes the previous epoch's pruning before training — confirm in models
        model.unprune()
        train(model, criterion, optimizer, epoch)
        print(f'[{epoch + 1} / {epochs}] Validation before pruning')
        acc = test(model, criterion, optimizer, "val", epoch)
        remaining = model.get_remaining(steepness, budget_type).item()
        remaining_before_pruning.append(remaining)
        valid_accuracy.append(acc)
        # exactly_zeros, exactly_ones = model.plot_zt()
        # exact_zeros.append(exactly_zeros)
        # exact_ones.append(exactly_ones)
        
        print(f'[{epoch + 1} / {epochs}] Validation after pruning')
        threshold, problem = model.prune(Vc, budget_type)
        # From here on `acc` is the post-pruning accuracy; it drives the checkpoint decision below.
        acc = test(model, criterion, optimizer, "val", epoch)
        remaining = model.get_remaining(steepness, budget_type).item()
        pruning_accuracy.append(acc)
        pruning_threshold.append(threshold)
        remaining_after_pruning.append(remaining)
        problems.append(problem)
        
        # Schedule step: beta rises by 0.1 / b_inc (capped at 6), gamma is
        # multiplied by 2 ** (1 / g_inc) (capped at 256).
        beta=min(6., beta+(0.1/b_inc))
        gamma=min(256, gamma*(2**(1./g_inc)))
        model.set_beta_gamma(beta, gamma)
        print("Changed beta to", beta, "changed gamma to", gamma)     
        
        # The checkpoint is written after the schedule step, so the stored
        # beta/gamma are the already-advanced values, not the ones in effect
        # when `acc` was measured.
        if acc>best_acc:
            print("**Saving checkpoint**")
            best_acc=acc
            torch.save({
                "epoch" : epoch+1,
                "beta" : beta,
                "gamma" : gamma,
                "prune_threshold":threshold,
                "state_dict" : model.state_dict(),
                "accuracy" : acc,
            }, f"checkpoints/{name}.pth")

        # Stack the six history lists and transpose so each row is one epoch.
        # NOTE(review): np.array coerces all six lists to a single dtype; if
        # `threshold` / `problem` are not plain numbers the whole table is
        # converted — verify against what model.prune returns.
        df_data=np.array([remaining_before_pruning, remaining_after_pruning, valid_accuracy, pruning_accuracy, pruning_threshold, problems]).T
        df = pd.DataFrame(df_data,columns = ['Remaining before pruning', 'Remaining after pruning', 'Valid accuracy', 'Pruning accuracy', 'Pruning threshold', 'problems'])
        # Rewritten in full every epoch, so the log on disk always reflects the epochs finished so far.
        df.to_csv(f"logs/{name}.csv")

Starting epoch 1 / 5


100%|██████████| 1407/1407 [17:57<00:00,  1.31it/s, loss=7.49]


[1 / 5] Validation before pruning


100%|██████████| 157/157 [00:35<00:00,  4.39it/s, acc=0.349, loss=2.9]


[1 / 5] Validation after pruning


100%|██████████| 157/157 [00:23<00:00,  6.69it/s, acc=0.0114, loss=279]


Changed beta to 1.02 changed gamma to 2.8284271247461903
**Saving checkpoint**
Starting epoch 2 / 5


100%|██████████| 1407/1407 [17:42<00:00,  1.32it/s, loss=3.07]


[2 / 5] Validation before pruning


100%|██████████| 157/157 [00:35<00:00,  4.40it/s, acc=0.456, loss=2.58]


[2 / 5] Validation after pruning


100%|██████████| 157/157 [00:22<00:00,  6.84it/s, acc=0.0084, loss=79]


Changed beta to 1.04 changed gamma to 4.000000000000001
Starting epoch 3 / 5


100%|██████████| 1407/1407 [18:15<00:00,  1.28it/s, loss=3.03]


[3 / 5] Validation before pruning


100%|██████████| 157/157 [00:36<00:00,  4.30it/s, acc=0.475, loss=2.75]


[3 / 5] Validation after pruning


100%|██████████| 157/157 [00:24<00:00,  6.46it/s, acc=0.0158, loss=29.2]


Changed beta to 1.06 changed gamma to 5.6568542494923815
**Saving checkpoint**
Starting epoch 4 / 5


100%|██████████| 1407/1407 [18:13<00:00,  1.29it/s, loss=3.11]


[4 / 5] Validation before pruning


100%|██████████| 157/157 [00:35<00:00,  4.44it/s, acc=0.497, loss=2.93]


[4 / 5] Validation after pruning


100%|██████████| 157/157 [00:23<00:00,  6.71it/s, acc=0.0448, loss=12.9]


Changed beta to 1.08 changed gamma to 8.000000000000002
**Saving checkpoint**
Starting epoch 5 / 5


100%|██████████| 1407/1407 [17:39<00:00,  1.33it/s, loss=3.17]


[5 / 5] Validation before pruning


100%|██████████| 157/157 [00:35<00:00,  4.46it/s, acc=0.539, loss=2.91]


[5 / 5] Validation after pruning


100%|██████████| 157/157 [00:23<00:00,  6.77it/s, acc=0.0798, loss=7.09]


Changed beta to 1.1 changed gamma to 11.313708498984763
**Saving checkpoint**
