In [13]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from timm import create_model
from sklearn.metrics import precision_recall_fscore_support, f1_score
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from tqdm.auto import tqdm
import wandb

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
class_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                'dog', 'frog', 'horse', 'ship', 'truck']

Files already downloaded and verified
Files already downloaded and verified


In [7]:
teacher_model = create_model('efficientnet_b2', pretrained=False, num_classes=10).to(device)
student_model = create_model('efficientnet_b0', pretrained=False, num_classes=10).to(device)
print('Учитель:', sum(p.numel() for p in teacher_model.parameters()))
print('Студент:', sum(p.numel() for p in student_model.parameters()))

Учитель: 7715084
Студент: 4020358


In [10]:
optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.6)
criterion = nn.CrossEntropyLoss()

In [11]:
def train(model, train_loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    return running_loss / len(train_loader)

def test(model, test_loader, criterion):
    model.eval()
    running_loss = 0.0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    precision_class, recall_class, f1_class, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average=None, zero_division=0
    )    
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )
    return running_loss / len(test_loader), precision, recall, f1, precision_class, recall_class, f1_class

In [15]:
wandb.init(
    project='Homework-4',
    name='efficientnet_b2_teacher'
)
model = teacher_model
num_epochs = 25

for i, epoch in enumerate(range(num_epochs)):
    train_loss = train(model, train_loader, criterion, optimizer)
    test_loss, precision, recall, f1, precision_class, recall_class, f1_class = test(model, test_loader, criterion)

    scheduler.step()

    log_data = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'precision_avg': precision,
        'recall_avg': recall,
        'f1_avg': f1,
    }
    
    for class_idx, value in enumerate(precision_class):
        log_data[f'Precision/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(recall_class):
        log_data[f'Recall/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(f1_class):
        log_data[f'F1/{class_labels[class_idx]}'] = value
    
    wandb.log(log_data)

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

In [16]:
wandb.finish()

VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
F1/airplane,▂▁▄▄▁▅▄▆▅▅▆▆▅▆▇▇▇▇▇██████
F1/automobile,▁▂▃▄▄▄▃▅▅▅▆▆▆▇▇▇▇▇▇██████
F1/bird,▁▂▄▂▄▅▅▅▅▄▆▆▆▆▇▇▇▇▇▇█████
F1/cat,▂▁▁▃▃▄▅▅▅▅▅▅▆▆▆▇▇▇▇█▇▇▇██
F1/deer,▂▂▁▃▃▃▃▃▃▅▅▅▆▆▇▇▇▇▇▇█████
F1/dog,▁▂▅▆▆▅▄▃▅▆▆▇▇▇▇▇▇█▇▇█▇███
F1/frog,▁▂▃▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇█████
F1/horse,▁▃▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███████
F1/ship,▂▂▁▃▄▄▄▅▅▅▆▆▆▇▇▇▇▇██▇████
F1/truck,▁▂▃▃▃▄▃▅▅▄▆▆▆▆▆▇▇▇▇█▇████

0,1
F1/airplane,0.68144
F1/automobile,0.76425
F1/bird,0.5625
F1/cat,0.47859
F1/deer,0.61583
F1/dog,0.56313
F1/frog,0.73648
F1/horse,0.70777
F1/ship,0.74218
F1/truck,0.72368


In [65]:
wandb.init(
    project='Homework-4',
    name='efficientnet_b0_student'
)
model = create_model('efficientnet_b0', pretrained=False, num_classes=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
num_epochs = 10

for i, epoch in enumerate(range(num_epochs)):
    train_loss = train(model, train_loader, criterion, optimizer)
    test_loss, precision, recall, f1, precision_class, recall_class, f1_class = test(model, test_loader, criterion)

    scheduler.step()

    log_data = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'precision_avg': precision,
        'recall_avg': recall,
        'f1_avg': f1,
    }
    
    for class_idx, value in enumerate(precision_class):
        log_data[f'Precision/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(recall_class):
        log_data[f'Recall/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(f1_class):
        log_data[f'F1/{class_labels[class_idx]}'] = value
    
    wandb.log(log_data)

wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.016 MB uploaded\r'), FloatProgress(value=0.4772499262319268, max=1.0…

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
F1/airplane,▁▁▂▃▆▅█▇█▇
F1/automobile,▁▁▆▆▇█████
F1/bird,▁▃▆▆▇██▇▇█
F1/cat,▄▁▃▇▆▇███▇
F1/deer,▁▇▅█▇▆▇▇█▇
F1/dog,▁▄▄▅▅▄▅█▆█
F1/frog,▁▃▃▄▅▆▇▇██
F1/horse,▁▄▅▆▆▇▇█▇█
F1/ship,▁▅▅▆▇█▇███
F1/truck,▁▄▅▆▆▆▇▇▇█

0,1
F1/airplane,0.49483
F1/automobile,0.57729
F1/bird,0.29774
F1/cat,0.32782
F1/deer,0.33812
F1/dog,0.43397
F1/frog,0.54664
F1/horse,0.51619
F1/ship,0.55035
F1/truck,0.5163


In [24]:
for param in teacher_model.parameters():
    param.requires_grad = False
teacher_model.eval()
None

# Дистилляция логитов

In [25]:
kl_criterion = nn.KLDivLoss(reduction='batchmean')
student_model = create_model('efficientnet_b0', pretrained=False, num_classes=10).to(device)

In [26]:
def train_distill_logit(student_model, teacher_model, train_loader, criterion, optimizer):
    student_model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)
            teacher_probs = nn.functional.softmax(teacher_outputs, dim=1)
        
        student_outputs = student_model(inputs)
        student_probs = nn.functional.log_softmax(student_outputs, dim=1)
        
        distillation_loss = kl_criterion(student_probs, teacher_probs)
        student_loss = criterion(student_outputs, labels)
        loss = 0.5 * distillation_loss + 0.5 * student_loss
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    return running_loss / len(train_loader)

In [27]:
wandb.init(project='Homework-4', name='distill_logits')

model = student_model
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
num_epochs = 10

for i, epoch in enumerate(range(num_epochs)):
    train_loss = train_distill_logit(model, teacher_model, train_loader, criterion, optimizer)
    test_loss, precision, recall, f1, precision_class, recall_class, f1_class = test(model, test_loader, criterion)

    scheduler.step()

    log_data = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'precision_avg': precision,
        'recall_avg': recall,
        'f1_avg': f1,
    }
    
    for class_idx, value in enumerate(precision_class):
        log_data[f'Precision/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(recall_class):
        log_data[f'Recall/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(f1_class):
        log_data[f'F1/{class_labels[class_idx]}'] = value
    
    wandb.log(log_data)

wandb.finish()

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

VBox(children=(Label(value='0.008 MB of 0.017 MB uploaded\r'), FloatProgress(value=0.4760385144429161, max=1.0…

0,1
F1/airplane,▁▃▄▄▆▆▅███
F1/automobile,▁▃▆▆▆▇████
F1/bird,▁▃▅▆▆▆▇███
F1/cat,▁▁▅▅▅▇▇▇█▇
F1/deer,▆▂▁▄▆▇▇▇▇█
F1/dog,▁▆▁▆▇▇▇▇█▇
F1/frog,▁▄▅▅▆▆▆███
F1/horse,▁▄▅▆▇▇█▇██
F1/ship,▁▂▄▅▆▆▇▇██
F1/truck,▁▃▅▆▆▅▇█▇█

0,1
F1/airplane,0.51189
F1/automobile,0.5921
F1/bird,0.36785
F1/cat,0.36299
F1/deer,0.42577
F1/dog,0.41099
F1/frog,0.56291
F1/horse,0.54882
F1/ship,0.58867
F1/truck,0.54089


# Дистилляция скрытых состояний

In [73]:
def get_features(model, x):
    features = []
    def hook(module, input, output):
        features.append(output)
    handle = model.global_pool.register_forward_hook(hook)
    model(x)
    handle.remove()
    return features[0]

In [81]:
class FeatureRegressor(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.regressor = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.regressor(x)

In [82]:
def train_distill_hidden(student_model, teacher_model, train_loader, criterion, optimizer, regressor, optimizer_regressor):
    student_model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        with torch.no_grad():
            teacher_features = get_features(teacher_model, inputs)
        student_features = get_features(student_model, inputs)

        adapted_student_features = regressor(student_features)
        
        student_outputs = student_model(inputs)
        
        feature_loss = -cosine_criterion(adapted_student_features, teacher_features).mean()
        classification_loss = criterion(student_outputs, labels)
        loss = 0.5 * feature_loss + 0.5 * classification_loss
        
        loss.backward()
        optimizer.step()
        optimizer_regressor.step()
        running_loss += loss.item()
    
    return running_loss / len(train_loader)

In [83]:
student_model = create_model('efficientnet_b0', pretrained=False, num_classes=10).to(device)
regressor = FeatureRegressor(1280, 1408).to(device)
optimizer_regressor = optim.Adam(regressor.parameters(), lr=0.001)
cosine_criterion = nn.CosineSimilarity(dim=1)

In [84]:
wandb.init(project='Homework-4', name='distill_hidden')

model = student_model
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
num_epochs = 10

for i, epoch in enumerate(range(num_epochs)):
    train_loss = train_distill_hidden(model, teacher_model, train_loader, criterion, optimizer, regressor, optimizer_regressor)
    test_loss, precision, recall, f1, precision_class, recall_class, f1_class = test(model, test_loader, criterion)

    scheduler.step()

    log_data = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'precision_avg': precision,
        'recall_avg': recall,
        'f1_avg': f1,
    }
    
    for class_idx, value in enumerate(precision_class):
        log_data[f'Precision/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(recall_class):
        log_data[f'Recall/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(f1_class):
        log_data[f'F1/{class_labels[class_idx]}'] = value
    
    wandb.log(log_data)

wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.016 MB uploaded\r'), FloatProgress(value=0.4776453934174826, max=1.0…

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

VBox(children=(Label(value='0.008 MB of 0.017 MB uploaded\r'), FloatProgress(value=0.4463183421516755, max=1.0…

0,1
F1/airplane,▃▁▅▆▆▇▇▇▇█
F1/automobile,▁▄▅▅▇▇▇█▇█
F1/bird,▂▂▁▆▇▆███▇
F1/cat,▅▁▃▆▇▅████
F1/deer,▁▅█▆▄▇▇███
F1/dog,▁▅▆▇▇█▆▇▇▇
F1/frog,▃▁▂▅▆▆▇▇▇█
F1/horse,▁▄▅▆▇▇████
F1/ship,▁▄▂▆▅▅▇▇██
F1/truck,▁▂▃▅▅▆▇▆▇█

0,1
F1/airplane,0.52459
F1/automobile,0.57886
F1/bird,0.3099
F1/cat,0.37291
F1/deer,0.36203
F1/dog,0.36137
F1/frog,0.54255
F1/horse,0.52981
F1/ship,0.57254
F1/truck,0.55464


In [32]:
wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=0.9392258809936453, max=1.0…

# Дистилляция с регрессором

In [66]:
class FeatureRegressor(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.regressor = nn.Sequential(
            nn.Conv2d(in_features, out_features, kernel_size=1),
            nn.BatchNorm2d(out_features),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.regressor(x)

In [67]:
def get_features(model, x):
    features = []
    def hook(module, input, output):
        features.append(output)
    handle = model.conv_head.register_forward_hook(hook)
    model(x)
    handle.remove()
    return features[0]

In [69]:
def train_distill_regressor(
    student_model, teacher_model, train_loader, criterion, optimizer, regressor, optimizer_regressor
):
    student_model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        optimizer_regressor.zero_grad()
        
        with torch.no_grad():
            teacher_features = get_features(teacher_model, inputs)
        student_features = get_features(student_model, inputs)
        
        adapted_student_features = regressor(student_features)
        
        student_outputs = student_model(inputs)
        
        feature_loss = mse_criterion(adapted_student_features, teacher_features)
        classification_loss = criterion(student_outputs, labels)
        loss = 0.1 * feature_loss + 0.9 * classification_loss
        
        loss.backward()
        optimizer.step()
        optimizer_regressor.step()
        
        running_loss += loss.item()
    
    return running_loss / len(train_loader)

In [70]:
student_model = create_model('efficientnet_b0', pretrained=False, num_classes=10).to(device)
regressor = FeatureRegressor(1280, 1408).to(device)
optimizer_regressor = optim.Adam(regressor.parameters(), lr=0.001)
mse_criterion = nn.MSELoss()

In [72]:
wandb.init(project='Homework-4', name='distill_regressor_0.1')

model = student_model
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
num_epochs = 10

for i, epoch in enumerate(range(num_epochs)):
    train_loss = train_distill_regressor(model, teacher_model, train_loader, criterion, optimizer, regressor, optimizer_regressor)
    test_loss, precision, recall, f1, precision_class, recall_class, f1_class = test(model, test_loader, criterion)

    scheduler.step()

    log_data = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'precision_avg': precision,
        'recall_avg': recall,
        'f1_avg': f1,
    }
    
    for class_idx, value in enumerate(precision_class):
        log_data[f'Precision/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(recall_class):
        log_data[f'Recall/{class_labels[class_idx]}'] = value
    
    for class_idx, value in enumerate(f1_class):
        log_data[f'F1/{class_labels[class_idx]}'] = value
    
    wandb.log(log_data)

wandb.finish()

VBox(children=(Label(value='0.016 MB of 0.016 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
F1/airplane,▅▁▂▄▃▆▅▅█▆
F1/automobile,▁▁▁▇▇█▇▆▆█
F1/bird,▁▅▄▅█▇▇▇█▇
F1/cat,▅▅▃▁▇▆▆██▇
F1/deer,▁▂▇▆▁▄▇▇▇█
F1/dog,▁▁▇█▆█▇▅▇█
F1/frog,▁▅▆▆▆▇▇▇██
F1/horse,▁▄▅▆▇▇▇███
F1/ship,▁▄▆▄▅▆▇▇██
F1/truck,▁▃▄▅▅▆▇▇█▇

0,1
F1/airplane,0.43526
F1/automobile,0.55812
F1/bird,0.31316
F1/cat,0.30755
F1/deer,0.38848
F1/dog,0.39135
F1/frog,0.5311
F1/horse,0.50754
F1/ship,0.5741
F1/truck,0.49267
