In [None]:
%%bash
# Fetch the project sources into the current working directory.
# Skip the clone when a checkout is already present so that re-running this
# cell (e.g. Restart & Run All) does not abort with
# "destination path 'no-distillation' already exists".
if [ ! -d no-distillation/.git ]; then
  git clone https://github.com/sunnytqin/no-distillation.git
fi

In [None]:
%%bash
# Create the output directory for expert results (no-op when it already exists).
RESULTS_DIR=/kaggle/working/expert_results
mkdir -p "$RESULTS_DIR"

In [None]:
%%bash
# Copy the dataset from the Kaggle input mount into the working directory.
DATASET=tiny-imagenet-200
DEST=/kaggle/working/data/$DATASET

mkdir -p "$DEST"
cp -r "/kaggle/input/$DATASET/$DATASET"/* "$DEST"/

In [None]:
%%bash
# Re-arrange the validation images into one sub-directory per class:
#   val/images/<image>  ->  val/images/<class>/<image>
# The first two whitespace-separated columns of val_annotations.txt are read
# as <image file name> <class id>; any further columns are ignored.

# Abort on the first error; in particular, never run the loop below in the
# wrong directory when the `cd` fails.
set -euo pipefail
cd /kaggle/working/data/tiny-imagenet-200/val

# `|| [ -n "$img" ]` also processes a final line that lacks a trailing newline.
while read -r img cls _ || [ -n "$img" ]; do
  # Ignore blank or malformed lines.
  if [ -z "$img" ] || [ -z "$cls" ]; then
    continue
  fi

  mkdir -p "images/$cls"

  # Only move files that are still in the flat layout, so the cell can be
  # re-run after a partial or complete previous run without failing.
  if [ -e "images/$img" ]; then
    mv "images/$img" "images/$cls/"
  fi
done < val_annotations.txt

In [None]:
from collections import defaultdict
import os, torch
from torch import nn

import numpy as np

import os
import sys
from tqdm import tqdm
from datetime import datetime

import sys, os

# Locations inside the cloned repository that must be importable.
repo_root = "/kaggle/working/no-distillation"
train_exp = os.path.join(repo_root, "train_expert")
softlabel_dir = os.path.join(repo_root, "softlabel")

# Each missing entry is pushed to the front of sys.path in turn, so the last
# one listed here ends up with the highest import priority.
for _pkg_dir in (repo_root, train_exp, softlabel_dir):
    if _pkg_dir not in sys.path:
        sys.path.insert(0, _pkg_dir)

# now your imports will resolve correctly:
from softlabel.utils import (
    get_dataset,
    get_network,
    get_daparam,
    TensorDataset,
    ParamDiffAug,
    DiffAugment,
    augment
)

In [None]:
class ConvNet(nn.Module):
    """Plain convolutional classifier.

    The network is ``net_depth`` repetitions of
    ``Conv2d(3x3, padding=1) -> GroupNorm -> ReLU -> AvgPool2d(2, 2)``
    followed by a single linear layer on the flattened feature map.

    Args:
        channel: Number of channels of the input images.
        num_classes: Size of the output (number of logits).
        net_width: Number of channels produced by every convolution.
        net_depth: Number of conv/norm/activation/pool stages.
        im_size: Spatial size of the input; indexed as ``im_size[0]`` and
            ``im_size[1]``.

    Attributes:
        shape_feat: ``[net_width, im_size[0] // 2**net_depth,
            im_size[1] // 2**net_depth]`` -- the shape of the feature map
            entering the classifier. It is fixed at construction time.
    """

    def __init__(self, channel, num_classes, net_width, net_depth, im_size):
        super().__init__()

        self.net_depth = net_depth
        self.net_width = net_width
        self.channel = channel
        self.im_size = im_size

        # Every stage halves both spatial dimensions (2x2 pooling, stride 2).
        downscale = 2 ** self.net_depth
        self.shape_feat = [
            self.net_width,
            self.im_size[0] // downscale,
            self.im_size[1] // downscale,
        ]

        # Only the first convolution sees the raw image channels.
        self.convs = nn.ModuleList([
            nn.Conv2d(self.channel if i == 0 else self.net_width, self.net_width, kernel_size=3, padding=1)
            for i in range(self.net_depth)
        ])

        # num_groups == num_channels, i.e. one group per channel.
        self.norms = nn.ModuleList([
            nn.GroupNorm(self.net_width, self.net_width)
            for _ in range(self.net_depth)
        ])

        self.activation = nn.ReLU(inplace=True)
        self.pooling = nn.AvgPool2d(kernel_size=2, stride=2)

        num_feat = self.shape_feat[0] * self.shape_feat[1] * self.shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        """Return the logits for a batch of images.

        ``x`` is assumed to be ``(batch, channel, im_size[0], im_size[1])``
        so that the flattened features match the classifier input size.

        The forward pass does not touch ``self.shape_feat``: the previous
        implementation halved it on every pooling step of every call, which
        left the attribute wrong after the first forward pass.
        """
        out = x
        for conv, norm in zip(self.convs, self.norms):
            out = conv(out)
            out = norm(out)
            out = self.activation(out)
            out = self.pooling(out)

        out = out.view(out.size(0), -1)
        return self.classifier(out)

# Training Expert

In [None]:
def epoch(mode, dataloader, net, optimizer, criterion, args, aug):
    """Run one pass over ``dataloader`` and return averaged metrics.

    Args:
        mode: ``'train'`` to update ``net`` with ``optimizer``; any other
            value runs the network in eval mode without parameter updates.
        dataloader: Iterable of batches; ``datum[0]`` is the image tensor and
            ``datum[1]`` the label tensor.
        net: Model to run; it is moved to ``args.device``.
        optimizer: Optimizer stepped once per batch in train mode. Not used
            otherwise (may be ``None``).
        criterion: Callable ``criterion(output, lab)`` returning a scalar loss.
        args: Namespace providing ``device``; ``dsa``, ``dsa_strategy``,
            ``dsa_param`` / ``dc_aug_param`` when ``aug`` is true;
            ``teacher_label`` in train mode; ``selection_strategy`` (and
            ``treat_subject`` for the ``data_efficient_*`` strategies) on the
            hard-label path.
        aug: When true, images are augmented before the forward pass.

    Returns:
        ``(loss_avg, acc_avg, net)``: per-sample averages over all counted
        samples, and the model (on ``args.device``). When no sample was
        counted both averages are ``0.0``.
    """
    loss_avg, acc_avg, num_exp = 0, 0, 0
    net = net.to(args.device)
    is_train = mode == 'train'

    if is_train:
        net.train()
    else:
        net.eval()

    # Autograd is only needed when backward() will be called; disabling it
    # for evaluation avoids building a graph for every test batch.
    with torch.set_grad_enabled(is_train):
        for i_batch, datum in enumerate(dataloader):
            img = datum[0].float().to(args.device)
            lab = datum[1].to(args.device)

            if aug:
                if args.dsa:
                    img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
                else:
                    img = augment(img, args.dc_aug_param, device=args.device)

            n_b = lab.shape[0]

            output = net(img)
            loss = criterion(output, lab)

            # Predicted class per sample, computed once and reused below.
            pred = np.argmax(output.detach().cpu().numpy(), axis=-1)

            if is_train and (args.teacher_label or len(datum[1].shape) > 1):
                # Soft / teacher labels: compare against the arg-max of the label.
                target = np.argmax(datum[1].cpu().data.numpy(), axis=-1)
                acc = np.sum(np.equal(pred, target))
            else:
                target = lab.cpu().data.numpy()
                acc = np.sum(np.equal(pred, target))
                if args.selection_strategy in ['data_efficient_treat_label','data_efficient_treat_image', 'data_efficient_control']:
                    # only examine treatment subject
                    subject_idx = torch.where(lab == args.treat_subject)[0].cpu().data.numpy()
                    n_b = len(subject_idx)
                    acc = np.sum(np.equal(pred[subject_idx], target[subject_idx]))

            loss_avg += loss.item()*n_b
            acc_avg += acc
            num_exp += n_b

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    # Nothing was counted (empty loader, or no sample of the treated class):
    # report zeros instead of raising ZeroDivisionError after a full pass.
    if num_exp == 0:
        return 0.0, 0.0, net

    loss_avg /= num_exp
    acc_avg /= num_exp

    return loss_avg, acc_avg, net


def expert_train(train_epochs):
    """Train expert network(s) and record the parameters after every epoch.

    The configuration (dataset, model, learning rate, batch sizes,
    augmentation strategy, paths) is fixed inside the function; only the
    number of epochs is configurable.

    Args:
        train_epochs: Number of epochs each expert is trained for.

    Returns:
        A list with one entry per trained epoch (``num_experts * train_epochs``
        entries in total, in training order). Each entry is a list of CPU
        tensors holding an independent copy of ``net.parameters()`` as they
        were at the end of that epoch.

    Side effects:
        Creates ``<buffer_path>/<dataset>/<subset>/<model_name>`` if missing
        and prints train/test accuracy for every epoch.
    """
    dataset            = "Tiny"
    subset             = "imagenette"
    model_name         = "ConvNetD4"
    num_experts        = 1
    lr_teacher         = 0.01
    batch_train        = 256
    batch_real         = 256
    dsa                = True
    dsa_strategy       = "color_crop_cutout_flip_scale_rotate"
    momentum           = 0.0
    weight_decay       = 0.0
    data_path          = "/kaggle/working/data/tiny-imagenet-200"
    buffer_path        = "/kaggle/working/no-distillation/results_100_S/"

    device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dsa_param = ParamDiffAug()

    # Minimal stand-in for the argparse namespace the repository helpers expect.
    class A: pass
    args = A()
    args.dataset  = dataset
    args.data_path= data_path
    args.batch_real = batch_real
    args.dsa      = dsa
    args.dsa_strategy = dsa_strategy
    args.dsa_param    = dsa_param
    args.zca = False
    args.device = device
    args.teacher_label = False
    args.selection_strategy = "random"

    channel, im_size, num_classes, _, _, _, \
    dst_train, dst_test, testloader, *_ = get_dataset(
        dataset, data_path, batch_real, args=args
    )

    save_dir = os.path.join(buffer_path, dataset, subset, model_name)
    os.makedirs(save_dir, exist_ok=True)

    # Materialise the whole training set as two tensors kept on the CPU.
    images, labels = [], []
    for img, lbl in tqdm(dst_train, desc="Loading real data"):
        images.append(img.unsqueeze(0))
        labels.append(lbl)
    images = torch.cat(images, 0).cpu()
    labels = torch.tensor(labels, dtype=torch.long)

    tensor_train = TensorDataset(images, labels)
    trainloader  = torch.utils.data.DataLoader(
        tensor_train, batch_size=batch_train, shuffle=True
    )

    criterion = nn.CrossEntropyLoss().to(device)
    args.dc_aug_param = get_daparam(dataset, model_name, model_name, None)
    args.dc_aug_param["strategy"] = "crop_scale_rotate"

    model_epoch_params = []

    for _ in range(num_experts):
        net = get_network(model_name, channel, num_classes, im_size, dist=False).to(device)
        net.train()
        optim = torch.optim.SGD(
            net.parameters(),
            lr=lr_teacher,
            momentum=momentum,
            weight_decay=weight_decay
        )
        # NOTE(review): the learning rate is constant for the whole run. The
        # previous code computed a decay milestone (train_epochs // 2 + 1)
        # but never applied it; confirm whether a decay step was intended.

        for e in tqdm(range(train_epochs)):
            train_loss, train_acc, _ = epoch(
                "train", trainloader, net, optim, criterion, args, aug=True
            )
            test_loss,  test_acc, _  = epoch(
                "test",  testloader, net, None,      criterion, args, aug=False
            )
            print(f"[Epoch {e:3d}]   train_acc={train_acc:.4f}   test_acc={test_acc:.4f}")
            # copy=True forces a real copy even when the model already lives
            # on the CPU. A plain `.cpu()` returns the same storage in that
            # case, so every snapshot would alias the live parameters and end
            # up equal to the final weights.
            model_epoch_params.append(
                [p.detach().to("cpu", copy=True) for p in net.parameters()]
            )

    return model_epoch_params

In [None]:
# Train the expert for 90 epochs; one parameter snapshot is kept per epoch.
expert_models = expert_train(train_epochs=90)

In [None]:
# Persist the per-epoch parameter snapshots to the Kaggle working directory.
EXPERT_MODELS_PATH = "/kaggle/working/expert_models.pt"
torch.save(expert_models, EXPERT_MODELS_PATH)