In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import timm
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import json
import argparse

from models import *
from losses import CosineSimilarityLoss
from utils import progress_bar

In [2]:
info = {}
train_loss_history = []
train_acc_history = []
test_loss_history = []
test_acc_history = []
BATCH_SIZE = 128

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [4]:
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.CIFAR10(
    root='../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [5]:
pretrain_model = timm.create_model('resnet34')
pretrain_model = pretrain_model.to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
feature_maps_criterion = CosineSimilarityLoss()
optimizer = optim.AdamW(pretrain_model.parameters(), lr=4e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [7]:
# Training
def train(epoch, net, net_name, FMCE=False):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        if FMCE:
            outputs,feature_maps = net(inputs)
        elif FMCE == False and net_name.startswith('DueHeadNet'):
            outputs,_ = net(inputs)
        else:
            outputs = net(inputs)
        loss = criterion(outputs, targets)
        if FMCE:
            feature_maps_a, feature_maps_b = feature_maps[0], feature_maps[1]
            feature_maps_loss = feature_maps_criterion(feature_maps_b,feature_maps_a)
            loss += feature_maps_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        #              % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        print(f"Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} ({correct}/{total})", end='\r')
        
    train_loss_history.append(train_loss / len(trainloader))
    train_acc_history.append(100.*correct/total)

In [8]:
def test(epoch, net, net_name, FMCE=False):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            if FMCE:
                outputs, feature_maps = net(inputs)
            elif FMCE == False and net_name.startswith('DueHeadNet'):
                outputs,_ = net(inputs)
            else:
                outputs = net(inputs)
            loss = criterion(outputs, targets)
            if FMCE:
                feature_maps_a, feature_maps_b = feature_maps[0], feature_maps[1]
                feature_maps_loss = feature_maps_criterion(feature_maps_a,feature_maps_b)
                loss += feature_maps_loss
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            #              % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            print(f"Loss: {test_loss/(batch_idx+1)} | Acc: {100.*correct/total} ({correct}/{total})", end='\r')
            
    test_loss_history.append(test_loss / len(testloader))
    test_acc_history.append(100.*correct/total)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print()
        print('New best model found!')
        best_acc = acc

In [9]:
for epoch in range(start_epoch, start_epoch+30):
    print()
    train(epoch, pretrain_model, 'ResNet34')
    print()
    test(epoch, pretrain_model, 'ResNet34')
    scheduler.step()



Epoch: 0
Loss: 1.858 | Acc: 38.512 (19256/50000)
Loss: 1.3207306265830994 | Acc: 51.84 (5184/10000)(5130/9900))
New best model found!


Epoch: 1
Loss: 1.262 | Acc: 53.986 (26993/50000)
Loss: 1.1066795098781586 | Acc: 60.13 (6013/10000) (5952/9900)
New best model found!


Epoch: 2
Loss: 1.090 | Acc: 60.768 (30384/50000)
Loss: 1.0333705765008927 | Acc: 63.6 (6360/10000)4 (6303/9900)
New best model found!


Epoch: 3
Loss: 0.985 | Acc: 64.622 (32311/50000)
Loss: 0.9449664986133576 | Acc: 66.48 (6648/10000)6583/9900))
New best model found!


Epoch: 4
Loss: 0.906 | Acc: 67.668 (33834/50000)
Loss: 0.8374721300601959 | Acc: 70.25 (7025/10000)(6957/9900)
New best model found!


Epoch: 5
Loss: 0.844 | Acc: 70.074 (35037/50000)
Loss: 0.8021611362695694 | Acc: 71.66 (7166/10000)(7096/9900)
New best model found!


Epoch: 6
Loss: 0.792 | Acc: 71.876 (35938/50000)
Loss: 0.7898948115110397 | Acc: 72.24 (7224/10000)(7152/9900)
New best model found!


Epoch: 7
Loss: 0.753 | Acc: 73.330 (36665/50000)
L

In [10]:
class DueHeadNet(nn.Module):
    def __init__(self, num_classes=101, base_model="seresnet34", pretrain_model=None):
        super(DueHeadNet, self).__init__()
        self.pretrain_model = pretrain_model
        self.model2 = timm.create_model(base_model, num_classes=num_classes)
        self.feature_table = {
            "seresnet18": 512,
            "seresnet34": 512,
            "seresnet101": 2048,
            "seresnet101": 2048,
            "seresnet152": 2048
        }
        self.cls = nn.Sequential(
            nn.Linear(self.feature_table[base_model]*2, 512),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        feature_maps_list = []
        with torch.no_grad():
            fe_map_a = self.pretrain_model.forward_features(x)
        fe_map_b = self.model2.forward_features(x)
        feature_maps_list.append(fe_map_a)
        feature_maps_list.append(fe_map_b)
        feature_maps = torch.stack(feature_maps_list, dim=1)
        feature_maps = feature_maps.view(feature_maps.size(0), -1)
        logits = self.cls(feature_maps)
        return logits, feature_maps_list

In [11]:
new_model = DueHeadNet(num_classes=10, base_model='seresnet34', pretrain_model=pretrain_model)
new_model = new_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(new_model.parameters(), lr=4e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

for epoch in range(start_epoch, start_epoch+30):
    print()
    train(epoch, new_model, 'DueHeadNet(seresnet34)', FMCE=True)
    print()
    test(epoch, new_model, 'DueHeadNet(seresnet34)', FMCE=True)
    scheduler.step()

print('Best accuracy:', best_acc)



Epoch: 0
Loss: 0.802 | Acc: 86.080 (43040/50000)
Loss: 0.8185107636451722 | Acc: 82.7 (8270/10000)(8180/9900))
New best model found!


Epoch: 1
Loss: 0.560 | Acc: 88.344 (44172/50000)
Loss: 0.7630088973045349 | Acc: 82.66 (8266/10000)(8175/9900)

Epoch: 2
Loss: 0.507 | Acc: 88.390 (44195/50000)
Loss: 0.7127496400475501 | Acc: 83.07 (8307/10000)8131/9800))
New best model found!


Epoch: 3
Loss: 0.471 | Acc: 88.744 (44372/50000)
Loss: 0.6837918168306351 | Acc: 83.07 (8307/10000)(8216/9900)

Epoch: 4
Loss: 0.451 | Acc: 89.034 (44517/50000)
Loss: 0.6872548016905785 | Acc: 82.84 (8284/10000)(8194/9900)

Epoch: 5
Loss: 0.430 | Acc: 89.342 (44671/50000)
Loss: 0.6836218824982643 | Acc: 83.19 (8319/10000)(8228/9900)
New best model found!


Epoch: 6
Loss: 0.426 | Acc: 89.204 (44602/50000)
Loss: 0.6861949688196183 | Acc: 83.4 (8340/10000) (8250/9900)
New best model found!


Epoch: 7
Loss: 0.428 | Acc: 89.392 (44696/50000)
Loss: 0.6812883311510086 | Acc: 83.16 (8316/10000)(8225/9900)

Epoch: 8
L

In [12]:
# params comparison
print("ResNet34 params: ", sum(p.numel() for p in pretrain_model.parameters()))
print("DueHeadNet params: ", sum(p.numel() for p in new_model.parameters()))

ResNet34 params:  21797672
DueHeadNet params:  43780148


In [13]:
best_acc

83.7

In [14]:
info['DueHeadNet(seresnet34)'] = {
    'best_acc': best_acc,
    'parms': sum(p.numel() for p in new_model.parameters())
}

In [15]:
new_model_FMCE = DueHeadNet(num_classes=10, base_model='seresnet34', pretrain_model=pretrain_model)
new_model_FMCE = new_model_FMCE.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(new_model_FMCE.parameters(), lr=4e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

for epoch in range(start_epoch, start_epoch+30):
    print()
    train(epoch, new_model_FMCE, 'DueHeadNet(seresnet34)NFMCE', FMCE=False)
    print()
    test(epoch, new_model_FMCE, 'DueHeadNet(seresnet34)NFMCE', FMCE=False)
    scheduler.step()

print('Best accuracy:', best_acc)



Epoch: 0
Loss: 0.412 | Acc: 86.280 (43140/50000)
Loss: 0.5433474797010421 | Acc: 82.92 (8292/10000)8203/9900))

Epoch: 1
Loss: 0.341 | Acc: 88.132 (44066/50000)
Loss: 0.5500798085331917 | Acc: 83.04 (8304/10000)(8213/9900)

Epoch: 2
Loss: 0.324 | Acc: 88.624 (44312/50000)
Loss: 0.5535410463809967 | Acc: 83.24 (8324/10000)(8235/9900)

Epoch: 3
Loss: 0.319 | Acc: 88.868 (44434/50000)
Loss: 0.5501968100667 | Acc: 83.38 (8338/10000)3 (8247/9900))

Epoch: 4
Loss: 0.318 | Acc: 88.866 (44433/50000)
Loss: 0.5423311738669873 | Acc: 83.5 (8350/10000) (8260/9900)

Epoch: 5
Loss: 0.313 | Acc: 88.890 (44445/50000)
Loss: 0.5515295077860355 | Acc: 83.16 (8316/10000)8226/9900))

Epoch: 6
Loss: 0.311 | Acc: 89.174 (44587/50000)
Loss: 0.557744921296835 | Acc: 83.27 (8327/10000)(8236/9900))

Epoch: 7
Loss: 0.317 | Acc: 88.818 (44409/50000)
Loss: 0.5463577573001385 | Acc: 83.28 (8328/10000)(8238/9900)

Epoch: 8
Loss: 0.307 | Acc: 89.252 (44626/50000)
Loss: 0.5529709309339523 | Acc: 83.42 (8342/10000)(82

In [16]:
best_acc

83.7

In [17]:
info['DueHeadNet(seresnet34)NFMCE'] = {
    'best_acc': best_acc,
    'parms': sum(p.numel() for p in new_model.parameters())
}

In [18]:
# train a single seresnet101 to compare
seresnet34 = timm.create_model('seresnet101', num_classes=10)
seresnet34 = seresnet34.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(seresnet34.parameters(), lr=4e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)    

for epoch in range(start_epoch, start_epoch+30):
    best_acc = 0
    print()
    train(epoch, seresnet34, 'seresnet101')
    print()
    test(epoch, seresnet34, 'seresnet101')
    scheduler.step()

print("seresnet101 params: ", sum(p.numel() for p in seresnet34.parameters()))
print(best_acc)



Epoch: 0
Loss: 1.723 | Acc: 36.308 (18154/50000)
Loss: 1.3770407843589783 | Acc: 48.44 (4844/10000)(4796/9900))
New best model found!


Epoch: 1
Loss: 1.309 | Acc: 52.136 (26068/50000)
Loss: 1.189563626050949 | Acc: 57.21 (5721/10000) (5662/9900))
New best model found!


Epoch: 2
Loss: 1.121 | Acc: 59.530 (29765/50000)
Loss: 1.023106464743614 | Acc: 63.04 (6304/10000)(6180/9800)0)
New best model found!


Epoch: 3
Loss: 1.007 | Acc: 63.838 (31919/50000)
Loss: 0.9061576277017593 | Acc: 68.17 (6817/10000)(6749/9900)
New best model found!


Epoch: 4
Loss: 0.910 | Acc: 67.394 (33697/50000)
Loss: 0.8659803628921509 | Acc: 69.59 (6959/10000)(6893/9900)
New best model found!


Epoch: 5
Loss: 0.845 | Acc: 69.850 (34925/50000)
Loss: 0.8216940021514892 | Acc: 71.46 (7146/10000)(7076/9900)
New best model found!


Epoch: 6
Loss: 0.785 | Acc: 72.096 (36048/50000)
Loss: 0.8091639101505279 | Acc: 71.82 (7182/10000)(7115/9900)
New best model found!


Epoch: 7
Loss: 0.742 | Acc: 73.732 (36866/50000)
L

In [19]:
info['seresnet101'] = {
    'best_acc': best_acc,
    'parms': sum(p.numel() for p in seresnet34.parameters())
}

In [20]:
info

{'DueHeadNet(seresnet34)': {'best_acc': 83.7, 'parms': 43780148},
 'DueHeadNet(seresnet34)NFMCE': {'best_acc': 83.7, 'parms': 43780148},
 'seresnet101': {'best_acc': 83.25, 'parms': 47298362}}

In [21]:
with open('info.json', 'w') as f:
    json.dump(info, f)