In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import timm
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split

import os
import json
import argparse

from utils import progress_bar

In [None]:
# Per-epoch metric logs that train()/test() append to, plus the summary
# dict that is dumped to info.json at the end of the notebook.
train_loss_history, train_acc_history = [], []
test_loss_history, test_acc_history = [], []
info = {}
BATCH_SIZE = 32  # training mini-batch size (the test loader uses its own size)

In [None]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat across "Restart & Run All" (cuDNN kernels may still be non-deterministic).
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0     # best test accuracy seen so far; test() raises it
best_param = 0   # placeholder; test() replaces it with the best state_dict
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [None]:
# Data
print('==> Preparing data..')
SPLIT_SEED = 42  # fixed seed so the 80/20 train/test split is identical on every run

# Deterministic preprocessing shared by both splits. Grayscale(3) forces every
# image to 3 channels (presumably because some Caltech101 images are
# single-channel — TODO confirm this is the intent, it also drops colour).
_base_transforms = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(256),
    transforms.CenterCrop(224),
]
_tensor_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Random flipping is augmentation and must only be applied to training images.
# Previously a single transform was shared by both random_split halves, so test
# images were randomly flipped too and test accuracy was noisy.
transform_train = transforms.Compose(
    _base_transforms + [transforms.RandomHorizontalFlip()] + _tensor_transforms)
transform_test = transforms.Compose(_base_transforms + _tensor_transforms)
transform_full = transform_train  # old name, kept for backward compatibility

# Two views of the same files that differ only in their transform. The split
# indices below are shared, so no image lands in both subsets.
dataset = torchvision.datasets.Caltech101(
    root='../data', target_type='category', download=True, transform=transform_train)
dataset_eval = torchvision.datasets.Caltech101(
    root='../data', target_type='category', download=False, transform=transform_test)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
split_order = torch.randperm(len(dataset), generator=split_generator).tolist()
train_dataset = torch.utils.data.Subset(dataset, split_order[:train_size])
test_dataset = torch.utils.data.Subset(dataset_eval, split_order[train_size:])
assert len(train_dataset) + len(test_dataset) == len(dataset)

trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testloader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)

In [None]:
# SE-ResNet-34 with a 101-way head; no `pretrained=` argument is passed to timm.
# It is trained first in this notebook and later wrapped as DueHeadNet's frozen branch.
pretrain_model = timm.create_model('seresnet34', num_classes=101).to(device)

In [None]:
# Stage-1 loss, AdamW optimizer and a cosine LR schedule over 20 epochs.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=pretrain_model.parameters(), lr=4e-4, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

In [None]:
# Training
def train(epoch, net, net_name, FMCE=False):
    """Run one optimisation epoch of ``net`` over the global ``trainloader``.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        epoch: epoch index, only printed in the header line.
        net: the model being trained.
        net_name: a name starting with ``'DueHeadNet'`` signals that ``net``
            returns ``(logits, feature_maps)`` rather than bare logits.
        FMCE: when truthy, ``net`` must return ``(logits, feature_maps)`` and
            the mean cosine similarity between ``feature_maps[1]`` and
            ``feature_maps[0]`` is added to the cross-entropy loss.

    Side effects:
        Appends the epoch's mean batch loss and accuracy (%) to
        ``train_loss_history`` and ``train_acc_history``.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for step, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        # DueHeadNet-style models return (logits, feature_maps); plain
        # backbones return only logits.
        if FMCE:
            logits, feature_maps = net(inputs)
        elif FMCE == False and net_name.startswith('DueHeadNet'):
            logits, _ignored = net(inputs)
        else:
            logits = net(inputs)

        loss = criterion(logits, targets)
        if FMCE:
            # Minimising this extra term pushes the two feature maps apart.
            loss += F.cosine_similarity(feature_maps[1], feature_maps[0], dim=1).mean()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_seen += targets.size(0)
        n_correct += logits.max(1).indices.eq(targets).sum().item()

        mean_loss = loss_sum / (step + 1)
        running_acc = 100. * n_correct / n_seen
        print(f"Loss: {mean_loss:.3f} | Acc: {running_acc:.3f} ({n_correct}/{n_seen})", end='\r')

    train_loss_history.append(loss_sum / len(trainloader))
    train_acc_history.append(100. * n_correct / n_seen)

In [None]:
def test(epoch, net, net_name, FMCE=False):
    """Evaluate ``net`` on the global ``testloader`` and remember the best weights.

    Relies on the module-level ``criterion`` and ``device``.

    Args:
        epoch: epoch index (not used in the body; kept for symmetry with ``train``).
        net: model to evaluate.
        net_name: a name starting with ``'DueHeadNet'`` signals that ``net``
            returns ``(logits, feature_maps)`` rather than bare logits.
        FMCE: when truthy, the reported loss also includes the mean cosine
            similarity between the two feature maps returned by ``net``.

    Side effects:
        Appends the mean batch loss and accuracy (%) to ``test_loss_history`` /
        ``test_acc_history``. When accuracy exceeds the global ``best_acc``,
        updates ``best_acc`` and stores an independent deep copy of
        ``net.state_dict()`` in the global ``best_param``.
    """
    global best_acc
    global best_param
    import copy

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            if FMCE:
                outputs, feature_maps = net(inputs)
            elif net_name.startswith('DueHeadNet'):
                # DueHeadNet always returns (logits, feature_maps).
                outputs, _ = net(inputs)
            else:
                outputs = net(inputs)
            loss = criterion(outputs, targets)
            if FMCE:
                feature_maps_a, feature_maps_b = feature_maps[0], feature_maps[1]
                loss = loss + F.cosine_similarity(feature_maps_b, feature_maps_a, dim=1).mean()
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(f"Loss: {test_loss/(batch_idx+1)} | Acc: {100.*correct/total} ({correct}/{total})", end='\r')

    test_loss_history.append(test_loss / len(testloader))
    test_acc_history.append(100.*correct/total)

    # Track the best model seen so far.
    acc = 100.*correct/total
    if acc > best_acc:
        print()
        print('New best model found!')
        best_acc = acc
        # state_dict() returns references to the live parameter/buffer tensors.
        # Without a copy, `best_param` would keep changing with later training
        # steps and always equal the *latest* weights, not the best ones.
        best_param = copy.deepcopy(net.state_dict())

In [None]:
# Stage 1: train the stand-alone seresnet34, evaluating after every epoch.
for epoch in range(start_epoch, start_epoch + 20):
    # Each phase is preceded by a blank line so the '\r' progress lines don't collide.
    for phase in (train, test):
        print()
        phase(epoch, pretrain_model, 'seresnet34')
    scheduler.step()

In [None]:
# Load the weights test() stored in `best_param` when it last saw a new best accuracy.
pretrain_model.load_state_dict(state_dict=best_param)

In [None]:
class DueHeadNet(nn.Module):
    """Two-branch classifier: a frozen pretrained backbone plus a fresh trainable one.

    Both backbones' ``forward_features`` outputs are summed element-wise,
    flattened and passed through a small MLP head. ``forward`` returns
    ``(logits, [frozen_features, trainable_features])`` so callers can add a
    feature-level loss term.

    Args:
        num_classes: output size for both the new backbone and the head.
            NOTE(review): the default 1011 looks like a typo for 101
            (Caltech101); it is kept for interface compatibility and every
            call in this notebook passes num_classes explicitly.
        base_model: timm model name for the trainable branch. Its feature maps
            must have the same shape as ``pretrain_model``'s so they can be summed.
        pretrain_model: already-trained timm model used as a frozen feature
            extractor. It is always kept in eval mode (see ``train``).
    """

    def __init__(self, num_classes=1011, base_model="seresnet34", pretrain_model=None):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.model2 = timm.create_model(base_model, num_classes=num_classes)
        # Flattened forward_features() size for 224x224 inputs (7x7 spatial map).
        self.feature_table = {
            "seresnet18": 512*7*7,
            "seresnet34": 512*7*7,
            "seresnet50": 2048*7*7,
            "seresnet101": 2048*7*7,
            "seresnet152": 2048*7*7
        }
        if base_model in self.feature_table:
            in_features = self.feature_table[base_model]
        else:
            # Fall back to the backbone's reported channel count, still
            # assuming a 7x7 map (i.e. 224x224 inputs).
            in_features = self.model2.num_features * 7 * 7
        self.cls = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.GELU(),
            nn.Linear(512, num_classes)
        )
        if self.pretrain_model is not None:
            self.pretrain_model.eval()

    def train(self, mode=True):
        """Set train/eval mode, but keep the frozen backbone in eval mode.

        ``torch.no_grad()`` in ``forward`` stops gradients, but BatchNorm
        layers in train mode still update their running statistics. Without
        this override, training would silently alter the shared pretrained
        backbone.
        """
        super().train(mode)
        if self.pretrain_model is not None:
            self.pretrain_model.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            fe_map_a = self.pretrain_model.forward_features(x)
        fe_map_b = self.model2.forward_features(x)
        # Element-wise sum requires both backbones to produce identically shaped maps.
        feature_maps = fe_map_b + fe_map_a
        feature_maps = feature_maps.view(feature_maps.size(0), -1)
        logits = self.cls(feature_maps)
        return logits, [fe_map_a, fe_map_b]

In [None]:
# Stage 2: DueHeadNet trained with the feature-map cosine term (FMCE=True).
FMCE_EPOCHS = 45
new_model = DueHeadNet(num_classes=101, base_model='seresnet34', pretrain_model=pretrain_model)
new_model = new_model.to(device)
criterion = nn.CrossEntropyLoss()
# Optimise only the new branch and head. The wrapped pretrain_model is used
# under no_grad; if it were in the optimizer, then on torch versions where
# zero_grad() zeroes (rather than clears) grads, AdamW's decoupled weight decay
# would still shrink the "frozen" backbone.
fmce_trainable = [p for name, p in new_model.named_parameters()
                  if not name.startswith('pretrain_model.')]
optimizer = optim.AdamW(fmce_trainable, lr=4e-4, weight_decay=5e-4)
# T_max matches the epoch count. CosineAnnealingLR keeps following the cosine
# past T_max, so the previous T_max=30 with 45 epochs pushed the LR back up
# to half its peak over the final 15 epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=FMCE_EPOCHS)

best_acc = 0

for epoch in range(start_epoch, start_epoch + FMCE_EPOCHS):
    print()
    train(epoch, new_model, 'DueHeadNet(seresnet34)', FMCE=True)
    print()
    test(epoch, new_model, 'DueHeadNet(seresnet34)', FMCE=True)
    scheduler.step()

In [None]:
# Report the best test accuracy reached by the FMCE DueHeadNet run.
print('Best accuracy:', best_acc, sep=' ')

In [None]:
# params comparison. DueHeadNet's total includes the wrapped pretrain_model,
# because it is registered as a submodule.
def _count_params(model):
    """Return the total number of elements across all of ``model``'s parameters."""
    return sum(p.numel() for p in model.parameters())

print("seresnet34 params: ", _count_params(pretrain_model))
print("DueHeadNet params: ", _count_params(new_model))

In [None]:
# Display the best test accuracy from the FMCE DueHeadNet run as this cell's output.
best_acc

In [None]:
# Record the FMCE DueHeadNet result (parameter count includes the wrapped backbone).
duehead_param_count = sum(p.numel() for p in new_model.parameters())
info['DueHeadNet(seresnet34)'] = {'best_acc': best_acc, 'parms': duehead_param_count}

In [None]:
# Ablation: the same DueHeadNet architecture trained WITHOUT the feature-map
# cosine term (FMCE=False). Despite its name, `new_model_FMCE` is the non-FMCE model.
NFMCE_EPOCHS = 45
new_model_FMCE = DueHeadNet(num_classes=101, base_model='seresnet34', pretrain_model=pretrain_model)
new_model_FMCE = new_model_FMCE.to(device)
criterion = nn.CrossEntropyLoss()
# Optimise only the new branch and head, so AdamW weight decay can never touch
# the shared pretrain_model (it runs under no_grad and is meant to stay frozen).
nfmce_trainable = [p for name, p in new_model_FMCE.named_parameters()
                   if not name.startswith('pretrain_model.')]
optimizer = optim.AdamW(nfmce_trainable, lr=4e-4, weight_decay=5e-4)
# T_max matches the epoch count. With the previous T_max=30 and 45 epochs,
# the cosine schedule climbed back up to half the peak LR at the end.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NFMCE_EPOCHS)

best_acc = 0

for epoch in range(start_epoch, start_epoch + NFMCE_EPOCHS):
    print()
    train(epoch, new_model_FMCE, 'DueHeadNet(seresnet34)NFMCE', FMCE=False)
    print()
    test(epoch, new_model_FMCE, 'DueHeadNet(seresnet34)NFMCE', FMCE=False)
    scheduler.step()

print('Best accuracy:', best_acc)

In [None]:
# Display the best test accuracy from the non-FMCE DueHeadNet run as this cell's output.
best_acc

In [None]:
# Record the non-FMCE DueHeadNet result. Count the parameters of the model that
# was actually trained in this run (`new_model_FMCE`), not the FMCE `new_model`.
info['DueHeadNet(seresnet34)NFMCE'] = {
    'best_acc': best_acc,
    'parms': sum(p.numel() for p in new_model_FMCE.parameters())
}

In [None]:
# Baseline: train a single, larger seresnet50 from scratch for comparison.
BASELINE_EPOCHS = 30
seresnet50 = timm.create_model('seresnet50', num_classes=101)
seresnet50 = seresnet50.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(seresnet50.parameters(), lr=4e-4, weight_decay=5e-4)
# T_max matches the epoch count, as in stage 1. The previous T_max=60 with
# 30 epochs stopped the cosine half-way, leaving the LR at half its peak.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=BASELINE_EPOCHS)
best_acc = 0
for epoch in range(start_epoch, start_epoch + BASELINE_EPOCHS):
    print()
    train(epoch, seresnet50, 'seresnet50')
    print()
    test(epoch, seresnet50, 'seresnet50')
    scheduler.step()

print("seresnet50 params: ", sum(p.numel() for p in seresnet50.parameters()))
print(best_acc)

# Backward-compatible alias: this model used to be stored under the misleading
# name `seresnet34`, and the info-recording cell below still reads that name.
seresnet34 = seresnet50

In [None]:
# Record the seresnet50 baseline; the variable `seresnet34` holds that seresnet50 model.
baseline_param_count = sum(p.numel() for p in seresnet34.parameters())
info['seresnet50'] = {'best_acc': best_acc, 'parms': baseline_param_count}

In [None]:
# Display the collected per-model results (best accuracy and parameter count).
info

In [None]:
# Persist the collected results to info.json in the current working directory.
with open('info.json', mode='w') as results_file:
    json.dump(info, fp=results_file)