In [1]:
from pretrained_models.load_pretrained_models import load_models
from MRL import MRL_Linear_Layer,Matryoshka_CE_Loss

In [2]:
resnet18 = load_models('resnet18')



In [3]:
import torch
import torch.functional as F
import torch.nn as nn

In [4]:
# 去掉最后一层
resnet18 = torch.nn.Sequential(*list(resnet18.children())[:-1])

In [5]:
mrl_linear_layer = MRL_Linear_Layer(nesting_list=[16, 64, 128, 256, 512], num_classes=10)

In [6]:
mrl_resnet18 = torch.nn.Sequential(resnet18,
                                   nn.Flatten(start_dim=1),
                                    mrl_linear_layer)      

In [7]:
# load cifar10
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                            shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [8]:
import torch.optim as optim

criterion = Matryoshka_CE_Loss()
optimizer = optim.SGD(mrl_resnet18.parameters(), lr=0.001, momentum=0.9)


In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mrl_resnet18.to(device)

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [10]:
def train_epoch(model, data_loader, criterion, optimizer, device):
    
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    
    for data, target in data_loader:
        data = data.to(device)
        target = target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        correct_output = 0
        for each_output in output:
            _, predicted = each_output.max(1)
            correct_output += predicted.eq(target).sum().item()
        correct += (correct_output/len(output))
        total += target.size(0)
    return train_loss / len(data_loader), correct / total

In [11]:
def test_epoch(model, data_loader, criterion, device):
    
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            
            output = model(data)
            loss = criterion(output, target)
            
            test_loss += loss.item()
            correct_output = 0
            for each_output in output:
                _, predicted = each_output.max(1)
                correct_output += predicted.eq(target).sum().item()
            correct += (correct_output/len(output))
            total += target.size(0)
    return test_loss / len(data_loader), correct / total

In [12]:
epochs = 10

In [13]:
for epoch in range(epochs):
    
    train_loss, train_acc = train_epoch(mrl_resnet18, trainloader, criterion, optimizer, device)
    test_loss, test_acc = test_epoch(mrl_resnet18, testloader, criterion, device)
    
    print(f'Epoch [{epoch}/{epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

Epoch [0/10], Train Loss: 4.6095, Train Acc: 0.6935, Test Loss: 3.1628, Test Acc: 0.7892
Epoch [1/10], Train Loss: 2.7216, Train Acc: 0.8183, Test Loss: 2.8782, Test Acc: 0.8089
Epoch [2/10], Train Loss: 2.0603, Train Acc: 0.8619, Test Loss: 2.7673, Test Acc: 0.8172
Epoch [3/10], Train Loss: 1.5777, Train Acc: 0.8943, Test Loss: 2.8843, Test Acc: 0.8171
Epoch [4/10], Train Loss: 1.2574, Train Acc: 0.9147, Test Loss: 2.8718, Test Acc: 0.8213
Epoch [5/10], Train Loss: 1.0098, Train Acc: 0.9325, Test Loss: 2.9126, Test Acc: 0.8219
Epoch [6/10], Train Loss: 0.8397, Train Acc: 0.9433, Test Loss: 3.1314, Test Acc: 0.8217
Epoch [7/10], Train Loss: 0.6778, Train Acc: 0.9538, Test Loss: 3.3508, Test Acc: 0.8174
Epoch [8/10], Train Loss: 0.5597, Train Acc: 0.9622, Test Loss: 3.4976, Test Acc: 0.8163
Epoch [9/10], Train Loss: 0.5023, Train Acc: 0.9663, Test Loss: 3.4370, Test Acc: 0.8228
