In [1]:
# Import necessary packages.
import numpy as np
import pandas as pd
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder, VisionDataset
from torchsummary import summary
from models.student import *

# This is for the progress bar.
from tqdm.auto import tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration: every tunable value read by the cells below.
# Key order is kept stable because the whole dict is written to the log file.
cfg = dict(
    # --- where data is read from and where results are written ---
    dataset_root='./Food-11',
    save_dir='./outputs',
    exp_name="distill_small_10000",
    # --- data loading / optimisation ---
    batch_size=64,
    lr=3e-4,
    num_workers=4,
    seed=20220013,
    # simple baseline: CE, medium baseline: KD. See the Knowledge_Distillation part for more information.
    # NOTE(review): the visible cells select the loss through 'loss_fn' below, not through this key.
    loss_fn_type='CE',
    weight_decay=1e-5,
    grad_norm_max=10,   # max_norm passed to gradient clipping
    n_epochs=5000,      # train more steps to pass the medium baseline.
    patience=10000,     # early-stopping threshold on epochs without improvement
    # --- distillation loss weights ---
    alpha=0.35,         # weight of the KL (soft-target) term
    beta=0.35,          # weight of the feature L2 term; CE gets 1 - alpha - beta
    temperature=4.0,    # softmax temperature used for the KL term
    loss_fn='kd_with_features',
    # intermediate outputs that are matched between student and teacher
    dim_mapper_keys=['pre_layer', 'layer1', 'layer2', 'layer3', 'layer4', 'avg_pooled'],
)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Reproducibility: seed every random number generator used by the notebook
# (Python, NumPy, torch CPU and, when present, all CUDA devices) and ask cuDNN
# for deterministic kernels without auto-tuning.
myseed = cfg['seed']
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Everything produced by this experiment (log, checkpoints, TensorBoard events,
# submission) is written below <save_dir>/<exp_name>.
save_path = os.path.join(cfg['save_dir'], cfg['exp_name'])
os.makedirs(save_path, exist_ok=True)

# Minimal logging: every message goes both to the cell output and to log.txt.
# Opening with 'w' truncates any log left by a previous run of this experiment.
log_fw = open(f"{save_path}/log.txt", 'w')


def log(text):
    """Print ``text`` and append it as one line to the log file, flushing right away."""
    print(text)
    print(text, file=log_fw, flush=True)


# Record the configuration at the top of the log file.
log(cfg)

{'dataset_root': './Food-11', 'save_dir': './outputs', 'exp_name': 'distill_small_10000', 'batch_size': 64, 'lr': 0.0003, 'num_workers': 4, 'seed': 20220013, 'loss_fn_type': 'CE', 'weight_decay': 1e-05, 'grad_norm_max': 10, 'n_epochs': 5000, 'patience': 10000, 'alpha': 0.35, 'beta': 0.35, 'temperature': 4.0, 'loss_fn': 'kd_with_features', 'dim_mapper_keys': ['pre_layer', 'layer1', 'layer2', 'layer3', 'layer4', 'avg_pooled']}


In [4]:
# Per-channel normalisation shared by the training and the evaluation pipeline.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Validation / testing: no augmentation, only a deterministic resize + centre
# crop to 224x224 followed by tensor conversion and normalisation.
_test_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
test_tfm = transforms.Compose(_test_steps)

# AutoAugment (IMAGENET policy). It is not part of train_tfm; the dataset
# applies it to the PIL image before the transform when it is passed as
# `augmenter`.
policy = transforms.AutoAugmentPolicy.IMAGENET
augmenter = transforms.AutoAugment(policy=policy)

# Training: random crop resized to 224x224 and random horizontal flip, then the
# same tensor conversion and normalisation as at test time.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
train_tfm = transforms.Compose(_train_steps)

In [5]:
class FoodDataset(Dataset):
    """Image dataset over the ``.jpg`` files of a single directory.

    The label of a sample is parsed from its file name: the integer in front of
    the first underscore of the base name (``"3_17.jpg"`` -> ``3``). Files whose
    name does not start with such an integer (e.g. unlabeled test images like
    ``"0000.jpg"``) get the label ``-1``.

    Args:
        path: Directory that is scanned for ``.jpg`` files.
        tfm: Transform applied to every image (after ``augmenter`` if given).
        augmenter: Optional callable applied to the PIL image before ``tfm``.
        files: Optional explicit list of image paths. When given, it is used
            as-is and ``path`` is not scanned.
    """

    def __init__(self, path, tfm=test_tfm, augmenter=None, files = None):
        super().__init__()
        self.path = path
        if files is not None:
            self.files = files
        else:
            self.files = sorted(os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg"))
        # Show one sample path as a sanity check; guard against an empty folder
        # so the constructor does not die with an unrelated IndexError.
        if len(self.files) > 0:
            print(f"One {path} sample",self.files[0])
        else:
            print(f"No .jpg files found for {path}")
        self.transform = tfm
        if augmenter is not None:
            print('Use augmenter.')
        self.augmenter = augmenter

    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self,idx):
        """Return ``(transformed_image, label)`` for the file at position ``idx``."""
        fname = self.files[idx]
        # Force 3-channel RGB so grayscale / CMYK / RGBA files do not break the
        # 3-channel normalisation, and close the file handle once decoded.
        with Image.open(fname) as raw:
            im = raw.convert("RGB")
        if self.augmenter is not None:
            im = self.augmenter(im)
        im = self.transform(im)
        # Use the OS-aware base name so the label is also parsed correctly when
        # the path was joined with a backslash separator.
        label_text = os.path.basename(fname).split("_")[0]
        try:
            label = int(label_text)
        except ValueError:
            label = -1 # test has no label
        return im,label

In [6]:
# Loader options shared by the training and the validation split.
_loader_opts = dict(
    batch_size=cfg['batch_size'],
    num_workers=cfg['num_workers'],
    pin_memory=True,
)

# Training split: augmented (AutoAugment + train_tfm) and reshuffled every epoch.
train_dir = os.path.join(cfg['dataset_root'], "training")
train_set = FoodDataset(train_dir, tfm=train_tfm, augmenter=augmenter)
train_loader = DataLoader(train_set, shuffle=True, **_loader_opts)

# Validation split: deterministic preprocessing, fixed order.
valid_dir = os.path.join(cfg['dataset_root'], "validation")
valid_set = FoodDataset(valid_dir, tfm=test_tfm)
valid_loader = DataLoader(valid_set, shuffle=False, **_loader_opts)

One ./Food-11/training sample ./Food-11/training/0_0.jpg
Use augmenter.
One ./Food-11/validation sample ./Food-11/validation/0_0.jpg


In [7]:
from torch.utils.tensorboard import SummaryWriter

# (input channels, output channels) of the 1x1 convolution used for each
# intermediate output. The mappers are applied to student features before they
# are compared with the teacher's, so the first number is the student side and
# the second the teacher side.
dim_mapper = {
    'pre_layer': (32, 64),
    'layer1': (32, 64),
    'layer2': (64, 128),
    'layer3': (64, 256),
    'layer4': (64, 512),
    'avg_pooled': (64, 512),
    'pre_logits': (64, 512),
}

# One trainable 1x1 convolution per configured key whose two widths differ.
dim_mapper_models = {}
for key in cfg['dim_mapper_keys']:
    assert key in dim_mapper
    student_ch, teacher_ch = dim_mapper[key]
    if student_ch != teacher_ch:
        if key == 'pre_logits':
            # No mapper is implemented for 'pre_logits'.
            raise NotImplementedError
        projection = nn.Conv2d(student_ch, teacher_ch, 1)
        projection.to(device)
        dim_mapper_models[key] = projection
        nn.init.kaiming_normal_(projection.weight, mode="fan_out", nonlinearity="relu")

def set_dim_mapper_models_mode(mode):
    """Switch every configured dim-mapper module to train or eval mode.

    Args:
        mode: ``'train'`` or ``'eval'``. Any other value leaves the modules
            untouched.
    """
    if mode not in ('train', 'eval'):
        return
    training = (mode == 'train')
    for key in cfg['dim_mapper_keys']:
        mapper = dim_mapper_models.get(key, None)
        # Keys without a mapper (no entry in dim_mapper_models) are skipped.
        if mapper:
            mapper.train(training)

# TensorBoard writer; its log directory is the "tb" folder under this experiment's save_path.
writer = SummaryWriter(log_dir=f"{save_path}/tb")

In [None]:
# Training-length settings taken from the config.
n_epochs = cfg['n_epochs']
patience = cfg['patience']  # early stop once more than `patience` epochs pass without improvement

# Teacher: 11-class resnet18 whose weights are loaded from a checkpoint on disk.
# output_whole_layers=True is passed to both constructors; the losses below
# index the model outputs by key ('logits' plus intermediate layers).
teacher_ckpt_path = "./pretrain/best.ckpt"
teacher_model = resnet18(num_classes=11, output_whole_layers=True)
teacher_state = torch.load(teacher_ckpt_path, map_location='cpu')
teacher_model.load_state_dict(teacher_state)
summary(teacher_model, (3, 224, 224), device='cpu')

# Student: the small network that is trained in this notebook.
student_model = resnet_dp_small(num_classes=11, output_whole_layers=True)
summary(student_model, (3, 224, 224), device='cpu')

# Move both networks to the selected device only after the summaries were
# produced on the CPU.
for _model in (teacher_model, student_model):
    _model.to(device)

# # For the classification task, we use cross-entropy as the measurement of performance.
# criterion = nn.CrossEntropyLoss()
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=cfg['alpha'], temperature=cfg['temperature']):
    """Logit-level knowledge-distillation loss.

    Combines the KL divergence between the temperature-softened student and
    teacher distributions (weighted by ``alpha * temperature**2``) with the
    cross-entropy against the hard labels (weighted by ``1 - alpha``).

    Args:
        outputs: Student logits; softmax is taken over dim 1.
        labels: Ground-truth class indices for the cross-entropy term.
        teacher_outputs: Teacher logits with the same layout as ``outputs``.
        alpha: Weight of the soft-target term.
        temperature: Softmax temperature applied to both sets of logits.

    Returns:
        A single scalar loss tensor.
    """
    soft_student = F.log_softmax(outputs/temperature, dim=1)
    soft_teacher = F.softmax(teacher_outputs/temperature, dim=1)
    distill_term = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (alpha * temperature * temperature)
    hard_term = F.cross_entropy(outputs, labels) * (1. - alpha)
    return distill_term + hard_term

def loss_fn_l2(outputs, labels, teacher_outputs, alpha=cfg['alpha']):
    """Logit-matching loss: MSE to the teacher logits blended with cross-entropy.

    Args:
        outputs: Student logits.
        labels: Ground-truth class indices for the cross-entropy term.
        teacher_outputs: Teacher logits, same shape as ``outputs``.
        alpha: Weight of the MSE term; the cross-entropy term gets ``1 - alpha``.

    Returns:
        A single scalar loss tensor.
    """
    match_term = F.mse_loss(outputs, teacher_outputs) * alpha
    hard_term = F.cross_entropy(outputs, labels) * (1. - alpha)
    return match_term + hard_term

def loss_fn_kd_with_features(outputs, labels, teacher_outputs,
    alpha=cfg['alpha'], beta=cfg['beta'], temperature=cfg['temperature'],
    dim_mapper_keys=cfg['dim_mapper_keys']):
    """Knowledge-distillation loss with additional feature matching.

    The total loss is the sum of three weighted terms:

    * KL divergence between temperature-softened student and teacher logits,
      weighted by ``alpha * temperature**2``;
    * cross-entropy of the student logits against ``labels``, weighted by
      ``1 - alpha - beta``;
    * the sum over ``dim_mapper_keys`` of the MSE between (mapped) student
      features and teacher features, weighted by ``beta``.

    Args:
        outputs: Student outputs, indexable by ``'logits'`` and by every key in
            ``dim_mapper_keys``.
        labels: Ground-truth class indices.
        teacher_outputs: Teacher outputs with the same keys as ``outputs``.
        alpha: Weight of the KL term.
        beta: Weight of the feature term.
        temperature: Softmax temperature for the KL term.
        dim_mapper_keys: Names of the intermediate outputs to match.

    Returns:
        ``(total_loss, kl_loss, CE_loss, feats_l2_loss)``; every element is a
        scalar tensor (already multiplied by its weight), also when
        ``dim_mapper_keys`` is empty.
    """
    student_logits = outputs['logits']
    teacher_logits = teacher_outputs['logits']
    T = temperature
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")
    kl_loss = kl_loss_fn(F.log_softmax(student_logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (alpha * T * T)
    CE_loss = F.cross_entropy(student_logits, labels) * (1. - alpha - beta)
    # Start from a tensor zero (not a Python float) so the returned feature
    # term always supports .item(), even when no feature keys are configured.
    feats_l2_loss = student_logits.new_zeros(())
    l2_loss_fn = nn.MSELoss()
    for key in dim_mapper_keys:
        student_feats = outputs[key]
        teacher_feats = torch.flatten(teacher_outputs[key], start_dim=1)
        mapper = dim_mapper_models.get(key, None)
        if mapper is not None:
            # Project the student features to the teacher's channel width first.
            student_feats = mapper(student_feats)
        student_feats_teacher_dim = torch.flatten(student_feats, start_dim=1)
        feats_l2_loss = feats_l2_loss + l2_loss_fn(student_feats_teacher_dim, teacher_feats)
    feats_l2_loss = feats_l2_loss * beta
    return (kl_loss + CE_loss + feats_l2_loss, kl_loss, CE_loss, feats_l2_loss)
    
# torch.autograd.set_detect_anomaly(True)

# Resolve the loss function named in the config ("loss_fn_<name>") by looking
# it up in the notebook namespace instead of eval()-ing a config-derived string.
loss_fn_type = cfg['loss_fn']
loss_fn = globals().get(f"loss_fn_{loss_fn_type}")
if not callable(loss_fn):
    raise ValueError(
        f"Unknown cfg['loss_fn'] = {loss_fn_type!r}: no function named loss_fn_{loss_fn_type} is defined"
    )

# Parameters to optimise: every dim-mapper convolution that exists for a
# configured key, followed by all parameters of the student network.
trainable_params = []
for key in cfg['dim_mapper_keys']:
    mapper_model = dim_mapper_models.get(key, None)
    if mapper_model is not None:
        trainable_params.extend(mapper_model.parameters())
trainable_params.extend(student_model.parameters())

# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
optimizer = torch.optim.Adam(trainable_params, lr=cfg['lr'], weight_decay=cfg['weight_decay'])

# Trackers for best-checkpoint selection and early stopping.
stale = 0        # consecutive epochs without a new best validation accuracy
best_acc = 0.0   # best validation accuracy seen so far

# The teacher only provides targets, so it stays in eval mode for the whole run.
teacher_model.eval()
for epoch in range(n_epochs):

    # ---------- Training ----------
    # Make sure the student and the dim mappers are in train mode before training.
    student_model.train()
    set_dim_mapper_models_mode('train')

    # Per-batch records; losses are stored multiplied by the batch size so the
    # epoch value is a per-sample average even when the last batch is smaller.
    train_total_loss = []
    train_kl_loss = []
    train_CE_loss = []
    train_feats_l2_loss = []
    train_accs = []
    train_lens = []

    for batch in tqdm(train_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward the data; the teacher forward pass needs no gradients.
        outputs = student_model(imgs)
        with torch.no_grad():
            teacher_outputs = teacher_model(imgs)

        # The loop expects the four-element return of loss_fn_kd_with_features
        # (total, KL, CE and feature terms, each already weighted).
        total_loss, kl_loss, CE_loss, feats_l2_loss = loss_fn(outputs, labels, teacher_outputs)

        # Gradients stored in the parameters in the previous step should be cleared out first.
        optimizer.zero_grad()

        # Compute the gradients for parameters.
        total_loss.backward()

        # Clip the gradient norms for stable training.
        # NOTE(review): only the student's parameters are clipped; the dim-mapper
        # parameters that are also in the optimizer are not - confirm this is intended.
        grad_norm = nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=cfg['grad_norm_max'])

        # Update the parameters with computed gradients.
        optimizer.step()

        # Number of correct predictions in this batch, as a plain Python float
        # so no device tensors are kept alive across the epoch.
        acc = (outputs['logits'].argmax(dim=-1) == labels).float().sum().item()

        # Record the loss and accuracy.
        train_batch_len = len(imgs)
        train_total_loss.append(total_loss.item() * train_batch_len)
        train_kl_loss.append(kl_loss.item() * train_batch_len)
        train_CE_loss.append(CE_loss.item() * train_batch_len)
        train_feats_l2_loss.append(feats_l2_loss.item() * train_batch_len)
        train_accs.append(acc)
        train_lens.append(train_batch_len)

    train_sample_count = sum(train_lens)
    train_total_loss = sum(train_total_loss) / train_sample_count
    train_kl_loss = sum(train_kl_loss) / train_sample_count
    train_CE_loss = sum(train_CE_loss) / train_sample_count
    train_feats_l2_loss = sum(train_feats_l2_loss) / train_sample_count
    train_acc = sum(train_accs) / train_sample_count

    # Print the information.
    log(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = ({train_total_loss:.5f}, {train_kl_loss:.5f}, {train_CE_loss:.5f}, {train_feats_l2_loss:.5f}), acc = {train_acc:.5f}")
    writer.add_scalar('Loss_total/train', train_total_loss, epoch)
    writer.add_scalar('Loss_KL/train', train_kl_loss, epoch)
    writer.add_scalar('Loss_CE/train', train_CE_loss, epoch)
    writer.add_scalar('Loss_L2_feats/train', train_feats_l2_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)

    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    student_model.eval()
    set_dim_mapper_models_mode('eval')

    # These are used to record information in validation.
    valid_total_loss = []
    valid_kl_loss = []
    valid_CE_loss = []
    valid_feats_l2_loss = []
    valid_accs = []
    valid_lens = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)

        # We don't need gradient in validation; the loss is still computed for monitoring.
        with torch.no_grad():
            outputs = student_model(imgs)
            teacher_outputs = teacher_model(imgs)
            total_loss, kl_loss, CE_loss, feats_l2_loss = loss_fn(outputs, labels, teacher_outputs)

        # Number of correct predictions in this batch.
        acc = (outputs['logits'].argmax(dim=-1) == labels).float().sum().item()

        # Record the loss and accuracy, weighted by THIS batch's size (not by
        # the size of the last training batch).
        batch_len = len(imgs)
        valid_total_loss.append(total_loss.item() * batch_len)
        valid_kl_loss.append(kl_loss.item() * batch_len)
        valid_CE_loss.append(CE_loss.item() * batch_len)
        valid_feats_l2_loss.append(feats_l2_loss.item() * batch_len)
        valid_accs.append(acc)
        valid_lens.append(batch_len)

    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_sample_count = sum(valid_lens)
    valid_total_loss = sum(valid_total_loss) / valid_sample_count
    valid_kl_loss = sum(valid_kl_loss) / valid_sample_count
    valid_CE_loss = sum(valid_CE_loss) / valid_sample_count
    valid_feats_l2_loss = sum(valid_feats_l2_loss) / valid_sample_count
    valid_acc = sum(valid_accs) / valid_sample_count
    # Log the validation (not the training) total loss under the valid tag.
    writer.add_scalar('Loss_total/valid', valid_total_loss, epoch)
    writer.add_scalar('Loss_KL/valid', valid_kl_loss, epoch)
    writer.add_scalar('Loss_CE/valid', valid_CE_loss, epoch)
    writer.add_scalar('Loss_L2_feats/valid', valid_feats_l2_loss, epoch)
    writer.add_scalar('Accuracy/valid', valid_acc, epoch)

    # Update logs, save the best checkpoint and handle early stopping.
    valid_msg = f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = ({valid_total_loss:.5f}, {valid_kl_loss:.5f}, {valid_CE_loss:.5f}, {valid_feats_l2_loss:.5f}), acc = {valid_acc:.5f}"
    if valid_acc > best_acc:
        log(valid_msg + " -> best")
        writer.add_scalar('BestAcc/valid', valid_acc, epoch)
        log(f"Best model found at epoch {epoch}, saving model")
        torch.save(student_model.state_dict(), f"{save_path}/student_best.ckpt") # only save best to prevent output memory exceed error
        best_acc = valid_acc
        stale = 0
    else:
        log(valid_msg)
        stale += 1
        if stale > patience:
            log(f"No improvment {patience} consecutive epochs, early stopping")
            break
    log_fw.flush()
log("Finish training")
# Make sure the last TensorBoard events reach the disk before the log is closed.
writer.flush()
log_fw.close()

In [8]:
# Evaluation split: same deterministic preprocessing as validation, fixed order,
# loaded in the main process (num_workers=0).
eval_dir = os.path.join(cfg['dataset_root'], "evaluation")
eval_set = FoodDataset(eval_dir, tfm=test_tfm)
eval_loader = DataLoader(
    eval_set,
    batch_size=cfg['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

One ./Food-11/evaluation sample ./Food-11/evaluation/0000.jpg


In [20]:
# Load the best student checkpoint from {save_path}/student_best.ckpt into a
# freshly constructed student so this cell does not depend on the training
# cell's `student_model` variable.
student_model_best = resnet_dp_small(num_classes=11, output_whole_layers=True)
ckpt_path = f"{save_path}/student_best.ckpt" # the ckpt path of the best student model.
student_model_best.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
student_model_best.to(device)

# Start evaluate
student_model_best.eval()
eval_preds = [] # predicted class index for every evaluation image, in loader order

# We don't need gradients in evaluation.
with torch.no_grad():
    # Iterate the evaluation set by batches; its labels are placeholders and are ignored.
    for imgs, _ in tqdm(eval_loader):
        logits = student_model_best(imgs.to(device))['logits']
        # argmax over the last dim yields one prediction per image (assuming
        # (batch, classes) logits). No squeeze(): with a final batch of size 1
        # it would give a 0-d array that cannot be turned into a list.
        eval_preds.extend(logits.argmax(dim=-1).cpu().tolist())
# loss and acc can not be calculated because we do not have the true labels of the evaluation set.

def pad4(i):
    """Return ``str(i)`` left-padded with ``"0"`` to at least four characters."""
    text = str(i)
    return text.rjust(4, "0")

# Build the submission table: one row per evaluation image, with the
# zero-padded position of the image in eval_set as Id and the predicted class
# as Category.
ids = list(map(pad4, range(len(eval_set))))
categories = eval_preds

df = pd.DataFrame({'Id': ids, 'Category': categories})
submission_path = f"{save_path}/submission.csv"
df.to_csv(submission_path, index=False) # now you can download the submission.csv and upload it to the kaggle competition.

  0%|          | 0/35 [00:00<?, ?it/s]

100%|██████████| 35/35 [00:08<00:00,  4.26it/s]
