In [None]:
from pretrained_models.load_pretrained_models import load_models
from MRL import MRL_Linear_Layer,Matryoshka_CE_Loss

In [None]:
# Fetch the model registered under the name 'resnet18' via the project helper
# `load_models`.
# NOTE(review): presumably a pretrained torchvision-style ResNet-18 -- confirm
# in pretrained_models/load_pretrained_models.py.
resnet18 = load_models('resnet18')

In [None]:
import torch
import torch.functional as F
import torch.nn as nn

In [None]:
# Drop the last child module and keep the remaining ones as a Sequential.
# NOTE: this rebinds `resnet18`, so re-running the cell on the same kernel
# strips one more module each time; re-run the loading cell first.
resnet18 = nn.Sequential(*list(resnet18.children())[:-1])

In [None]:
# Classification head from the project's MRL module.
# NOTE(review): presumably builds one classifier per nested embedding width in
# `nesting_list` (16 ... 512) -- confirm against MRL.py. `num_classes=10`
# matches the 10 classes of the CIFAR-10 dataset loaded further down.
mrl_linear_layer = MRL_Linear_Layer(nesting_list=[16, 64, 128, 256, 512], num_classes=10)

In [None]:
# Full model: truncated backbone -> flatten everything after the batch
# dimension -> MRL classification head.
mrl_resnet18 = nn.Sequential(
    resnet18,
    nn.Flatten(start_dim=1),
    mrl_linear_layer,
)

In [None]:
# Load CIFAR-10 (downloaded into ./data on first use).
import torchvision
import torchvision.transforms as transforms

# Convert each image to a tensor, then normalise every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: batches of 64, reshuffled on every pass.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2
)

# Test split: batches of 64 in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2
)


In [None]:
import torch.optim as optim

# Loss from the project's MRL module; it is called as criterion(output, target)
# in the train/test helpers below.
criterion = Matryoshka_CE_Loss()

# Plain SGD with momentum over every parameter of the assembled model.
optimizer = optim.SGD(
    mrl_resnet18.parameters(),
    lr=0.001,
    momentum=0.9,
)


In [None]:
# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Move the model's parameters and buffers to the chosen device.
mrl_resnet18.to(device)

In [None]:
def train_epoch(model, data_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Module called as ``model(data)``. It may return either a
            sequence of logits tensors (one per head) or a single logits
            tensor.
        data_loader: Iterable yielding ``(data, target)`` batches. It does
            not need to support ``len()``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor (it is back-propagated and read with ``.item()``).
        optimizer: Optimizer whose gradients are zeroed before, and which is
            stepped after, every batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(mean_loss, accuracy)`` where ``mean_loss`` is
        the loss averaged over batches and ``accuracy`` is the number of
        correct predictions, averaged over the heads, divided by the number
        of samples. Either value is ``0.0`` when there is nothing to average
        over (no batches / no samples).
    """
    model.train()
    train_loss = 0.0
    correct = 0.0
    total = 0
    num_batches = 0  # counted manually so loaders without __len__ also work

    for data, target in data_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1

        # Treat a plain tensor as a single head so the per-head loop below
        # does not iterate over the rows of the batch.
        heads = (output,) if isinstance(output, torch.Tensor) else output
        correct_output = 0
        for each_output in heads:
            _, predicted = each_output.max(1)
            correct_output += predicted.eq(target).sum().item()
        # Average the correct-count over heads for this batch.
        correct += correct_output / len(heads)
        total += target.size(0)

    mean_loss = train_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy

In [None]:
def test_epoch(model, data_loader, criterion, device):
    
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            
            output = model(data)
            loss = criterion(output, target)
            
            test_loss += loss.item()
            correct_output = 0
            for each_output in output:
                _, predicted = each_output.max(1)
                correct_output += predicted.eq(target).sum().item()
            correct += (correct_output/len(output))
            total += target.size(0)
    return test_loss / len(data_loader), correct / total

In [None]:
# Number of full train + test passes performed by the training loop below.
epochs = 10

In [None]:
# Run one training pass and one evaluation pass per epoch and log the metrics.
for epoch in range(epochs):

    train_loss, train_acc = train_epoch(mrl_resnet18, trainloader, criterion, optimizer, device)
    test_loss, test_acc = test_epoch(mrl_resnet18, testloader, criterion, device)

    # `epoch` is 0-based; report it 1-based so the last line reads
    # [epochs/epochs] instead of stopping at [epochs-1/epochs].
    print(f'Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')