In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import simsiam.loader
import simsiam.builder
import simsiam.builder_resnet18_2
import simsiam.resnet18

class Net(nn.Module):
    
    """ Linear evaluation network. """
    
    def __init__(self):
        super().__init__()
        #self.lin1 = nn.Linear(512,50)
        self.lin2 = nn.Linear(512,10)
        #self.lin3 = nn.Linear(20,10)
        #self.lin4 = nn.Linear(20,10)
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
    def forward(self, x):
        #x = F.relu(self.lin1(x))
        #x = F.relu(self.lin2(x))
        #x = F.relu(self.lin3(x))
        x = self.lin2(x)
        return x

In [8]:
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
device = torch.device('cuda') 
model = simsiam.builder_resnet18_2.SimSiam(
        simsiam.resnet18.resnet18(), 2048, 512)

### Load model you want to evaluate ###
model = torch.load("fixlr100.pt")

images = torch.load('data_sup/X_250.pt')
labels = torch.load('data_sup/y_250.pt')

im = []
lab = []

model.to(device)
"""
for i in range(int(len(images)/4)):
    input_ = torch.cat((images[i].cuda(), images[i+1].cuda(), images[i+2].cuda(), images[i+3].cuda()), 0)
    im.append(model.forward_lat_pool(input_).detach())
    lab.append(torch.cat((labels[i].cuda(), labels[i+1].cuda(), labels[i+2].cuda(), labels[i+3].cuda()), 0))
    #input_ = torch.cat((images[i], images[i+1], images[i+2], images[i+3]), 0)
    #im.append(model.forward_lat_pool(input_).detach())
    #lab.append(torch.cat((labels[i], labels[i+1], labels[i+2], labels[i+3]), 0))
   """ 
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])
augmentation = [
    #transforms.RandomResizedCrop(28, scale=(0.2, 1.)),
    #transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    #transforms.RandomGrayscale(p=0.2),
    #transforms.RandomApply([simsiam.loader.GaussianBlur([.1, 2.])], p=0.5),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
    ]
data_storage = "./data"
trainset = torchvision.datasets.CIFAR10(root=data_storage, train=True,
                                        download=True, transform=simsiam.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    
train_loader = torch.utils.data.DataLoader(trainset, batch_size = 64, shuffle= True,
        num_workers=4, pin_memory=True, drop_last=True)


im1 = []
lab1 = []

for images1, labels1 in train_loader:
    input_ = images1[0].cuda()
    im1.append(model.forward_lat_pool(input_).detach())
    lab1.append(labels1.cuda())
    #print(input_.shape)
    
print(len(im1))

Files already downloaded and verified




781


In [9]:
import matplotlib.pyplot as plt
# NOTERA ATT Net() mÃ¥ste ha samma input dim som model har output dim. 
class SimSiamNet:
    """Model loading both trained model and linear evaluation network. """
    def __init__(self, model1, model2):
        super().__init__()
        self.simsiam = model1
        self.classifier = model2
        
    def forward(self, x):
        x = self.simsiam.forward_lat_pool(x).detach()
        x = self.classifier(x)
        return x

transform = transforms.Compose(
    [transforms.CenterCrop(28),
     transforms.ToTensor(),
     transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
def val_acc(simsiam_net):
    correct = 0
    total = 0
    with torch.no_grad():
        times = 0
        for data in testloader:
            times += 1
            images, labels = data
            images, labels = images.cuda(), labels.cuda()
            # calculate outputs by running images through the network
            outputs = simsiam_net.forward(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return correct / total

net = Net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.02, weight_decay=0, momentum=0.9)

val_acc_lst = []
net.to(device)
for epoch in range(90):  # loop over the dataset multiple times
    running_loss = 0.0
    
    for i in range(len(im1)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = im1[i], lab1[i]

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        #print(f"Pred: {outputs[0].argmax()}, Label: {labels[0]}")
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
    #if epoch % 1000 == 0:
    #    val_acc_lst.append(val_acc(SimSiamNet(model, net)))
    if epoch % 10 == 0:
        print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
        running_loss = 0.0
        
print(val_acc(SimSiamNet(model, net)))

Files already downloaded and verified
[1,   781] loss: 0.717
[11,   781] loss: 0.298
[21,   781] loss: 0.272
[31,   781] loss: 0.261
[41,   781] loss: 0.254
[51,   781] loss: 0.250
[61,   781] loss: 0.247
[71,   781] loss: 0.244
[81,   781] loss: 0.242




0.6451
