<a href="https://colab.research.google.com/github/jiminji0107/SimCLR/blob/master/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
#from torchlars import LARS
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class SimCLRDataTransform:
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=SimCLRDataTransform())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

class ResNetSimCLR(nn.Module):
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.backbone = models.resnet50(weights=None)

        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()

        dim_mlp = self.backbone.fc.in_features


        self.backbone.fc = nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, out_dim)
        )

    def forward(self, x):
        return self.backbone(x)


model = ResNetSimCLR(base_model='resnet18', out_dim=128).to(device)


def nt_xent_loss(out_1, out_2, temperature):
    batch_size = out_1.shape[0]
    out = torch.cat([out_1, out_2], dim=0)
    sim_matrix = F.cosine_similarity(out.unsqueeze(1), out.unsqueeze(0), dim=2)

    sim_ij = torch.diag(sim_matrix, batch_size)
    sim_ji = torch.diag(sim_matrix, -batch_size)

    positives = torch.cat([sim_ij, sim_ji], dim=0)
    negatives = sim_matrix[~torch.eye(2 * batch_size, dtype=bool)].view(2 * batch_size, -1)

    logits = torch.cat((positives.unsqueeze(1), negatives), dim=1)
    labels = torch.zeros(2 * batch_size).long().to(device)

    logits = logits / temperature
    loss = F.cross_entropy(logits, labels)
    return loss



temperature = 0.5
epochs = 100
learning_rate = 3e-4


#optimizer = optim.Adam(model.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=1e-4, momentum=0.9)
#optimizer = LARS(model.parameters(), lr=4.8, weight_decay = 1e-6)


for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for step, ((x_i, x_j), _) in enumerate(tqdm(train_loader,desc=f"Epoch [{epoch+1}/{epochs}]")):
        optimizer.zero_grad()
        x_i, x_j = x_i.to(device), x_j.to(device)
        x = torch.cat([x_i, x_j], dim=0)
        h = model(x)
        h_i, h_j = torch.chunk(h, 2)
        #h_i, h_j = model(x_i), model(x_j)
        loss = nt_xent_loss(h_i, h_j, temperature)


        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader)}')

torch.save(model, 'sim_res18_100.pth')

for param in model.backbone.parameters():
    param.requires_grad = False

class Classifier(nn.Module):
    def __init__(self, encoder, num_classes = 10):
        super(Classifier, self).__init__()
        self.encoder = nn.Sequential(*list(encoder.backbone.children())[:-1])
        dim_mlp = encoder.backbone.fc[0].in_features

        self.fc = nn.Linear(dim_mlp, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x).squeeze()
        logits = self.fc(features)
        return logits

transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

classifier = Classifier(model).to(device)
train_class = datasets.CIFAR10(root = './data', train=True, download=True, transform = transform)
train_class_load = torch.utils.data.DataLoader(train_class, batch_size = 128, shuffle = True, num_workers = 4)

test_class = datasets.CIFAR10(root = './data', train=True, download=True, transform = transform)
test_class_load = torch.utils.data.DataLoader(test_class, batch_size = 128, shuffle = True, num_workers = 4)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.fc.parameters(), lr=1e-4)

epochs = 20

for epoch in range(epochs):
    classifier.train()
    running_loss = 0.0

    for step, (images, labels) in enumerate(tqdm(train_class_load,desc=f"Epoch [{epoch+1}/{epochs}]")):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()


        outputs = classifier(images)
        loss = criterion(outputs, labels)


        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader)}')

torch.save(classifier, 'class_res18_100.pth')

classifier.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_class_load:
        images, labels = images.to(device), labels.to(device)

        outputs = classifier(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the model on the test images: {accuracy} %')


Using device: cuda
Files already downloaded and verified


Epoch [1/100]: 100%|██████████| 391/391 [02:39<00:00,  2.44it/s]


Epoch [1/100], Loss: 5.544525848935022


Epoch [2/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [2/100], Loss: 5.543976010568916


Epoch [3/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [3/100], Loss: 5.543975439827765


Epoch [4/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [4/100], Loss: 5.543972860516795


Epoch [5/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [5/100], Loss: 5.543767572973695


Epoch [6/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [6/100], Loss: 5.258118787994775


Epoch [7/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [7/100], Loss: 4.970643925239973


Epoch [8/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [8/100], Loss: 4.898660358565543


Epoch [9/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [9/100], Loss: 4.8528522676824


Epoch [10/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [10/100], Loss: 4.82768433539154


Epoch [11/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [11/100], Loss: 4.70300928954883


Epoch [12/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [12/100], Loss: 4.642453649159893


Epoch [13/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [13/100], Loss: 4.562916112982708


Epoch [14/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [14/100], Loss: 4.531945317602523


Epoch [15/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [15/100], Loss: 4.515484084253726


Epoch [16/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [16/100], Loss: 4.500249793157553


Epoch [17/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [17/100], Loss: 4.487564538141041


Epoch [18/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [18/100], Loss: 4.472110278161286


Epoch [19/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [19/100], Loss: 4.446300094389854


Epoch [20/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [20/100], Loss: 4.422378849190519


Epoch [21/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [21/100], Loss: 4.401982102552643


Epoch [22/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [22/100], Loss: 4.3840680402867935


Epoch [23/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [23/100], Loss: 4.371987075756882


Epoch [24/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [24/100], Loss: 4.367396015035527


Epoch [25/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [25/100], Loss: 4.3572033323595285


Epoch [26/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [26/100], Loss: 4.357045758410793


Epoch [27/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [27/100], Loss: 4.347388833380111


Epoch [28/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [28/100], Loss: 4.34333426629186


Epoch [29/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [29/100], Loss: 4.327843765468549


Epoch [30/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [30/100], Loss: 4.324971417941706


Epoch [31/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [31/100], Loss: 4.319269636402959


Epoch [32/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [32/100], Loss: 4.318613971895574


Epoch [33/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [33/100], Loss: 4.3125964229369105


Epoch [34/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [34/100], Loss: 4.311120814984412


Epoch [35/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [35/100], Loss: 4.299902981199572


Epoch [36/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [36/100], Loss: 4.29562010850443


Epoch [37/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [37/100], Loss: 4.292050075652959


Epoch [38/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [38/100], Loss: 4.28975141993569


Epoch [39/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [39/100], Loss: 4.283339526647192


Epoch [40/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [40/100], Loss: 4.276917589290063


Epoch [41/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [41/100], Loss: 4.2754313342101735


Epoch [42/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [42/100], Loss: 4.276196632848676


Epoch [43/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [43/100], Loss: 4.279417694072285


Epoch [44/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [44/100], Loss: 4.268293759402106


Epoch [45/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [45/100], Loss: 4.270716314120671


Epoch [46/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [46/100], Loss: 4.268163847496442


Epoch [47/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [47/100], Loss: 4.267672143019069


Epoch [48/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [48/100], Loss: 4.27283886387525


Epoch [49/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [49/100], Loss: 4.2619782579524434


Epoch [50/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [50/100], Loss: 4.261386643285337


Epoch [51/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [51/100], Loss: 4.262130368396145


Epoch [52/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [52/100], Loss: 4.256666015176212


Epoch [53/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [53/100], Loss: 4.2564932330490075


Epoch [54/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [54/100], Loss: 4.256783287238587


Epoch [55/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [55/100], Loss: 4.251118647168054


Epoch [56/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [56/100], Loss: 4.258285555388311


Epoch [57/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [57/100], Loss: 4.252809506243147


Epoch [58/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [58/100], Loss: 4.247813232109674


Epoch [59/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [59/100], Loss: 4.243687732140426


Epoch [60/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [60/100], Loss: 4.250658148694831


Epoch [61/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [61/100], Loss: 4.244276996768649


Epoch [62/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [62/100], Loss: 4.240784798131879


Epoch [63/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [63/100], Loss: 4.239558809553571


Epoch [64/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [64/100], Loss: 4.240709245052484


Epoch [65/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [65/100], Loss: 4.238108572142814


Epoch [66/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [66/100], Loss: 4.235912666906176


Epoch [67/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [67/100], Loss: 4.240307418891536


Epoch [68/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [68/100], Loss: 4.236482741582729


Epoch [69/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [69/100], Loss: 4.2356399916626915


Epoch [70/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [70/100], Loss: 4.234983738730936


Epoch [71/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [71/100], Loss: 4.233444944976846


Epoch [72/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [72/100], Loss: 4.233754679979875


Epoch [73/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [73/100], Loss: 4.22784568708571


Epoch [74/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [74/100], Loss: 4.231362927600246


Epoch [75/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [75/100], Loss: 4.230527282675819


Epoch [76/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [76/100], Loss: 4.230956796490018


Epoch [77/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [77/100], Loss: 4.229384102174998


Epoch [78/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [78/100], Loss: 4.226311723899354


Epoch [79/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [79/100], Loss: 4.226701590106311


Epoch [80/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [80/100], Loss: 4.229430923071663


Epoch [81/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [81/100], Loss: 4.2231671895517415


Epoch [82/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [82/100], Loss: 4.225689619703366


Epoch [83/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [83/100], Loss: 4.2283494637140535


Epoch [84/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [84/100], Loss: 4.225793028731481


Epoch [85/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [85/100], Loss: 4.225842579551365


Epoch [86/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [86/100], Loss: 4.223516669114837


Epoch [87/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [87/100], Loss: 4.2263389889846374


Epoch [88/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [88/100], Loss: 4.219963036534732


Epoch [89/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [89/100], Loss: 4.221752271018065


Epoch [90/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [90/100], Loss: 4.225395155684723


Epoch [91/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [91/100], Loss: 4.22459090640173


Epoch [92/100]: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s]


Epoch [92/100], Loss: 4.225926624844446


Epoch [93/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [93/100], Loss: 4.2216788941941905


Epoch [94/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [94/100], Loss: 4.219574209369357


Epoch [95/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [95/100], Loss: 4.221156095909645


Epoch [96/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [96/100], Loss: 4.2209239542636725


Epoch [97/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [97/100], Loss: 4.218190392569813


Epoch [98/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [98/100], Loss: 4.220238144745302


Epoch [99/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [99/100], Loss: 4.2148210880396615


Epoch [100/100]: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s]


Epoch [100/100], Loss: 4.222948063365029
Files already downloaded and verified
Files already downloaded and verified


Epoch [1/20]: 100%|██████████| 391/391 [00:23<00:00, 16.44it/s]


Epoch [1/20], Loss: 1.8912117179397427


Epoch [2/20]: 100%|██████████| 391/391 [00:23<00:00, 16.53it/s]


Epoch [2/20], Loss: 1.4919551638386133


Epoch [3/20]: 100%|██████████| 391/391 [00:23<00:00, 16.51it/s]


Epoch [3/20], Loss: 1.3501684784584338


Epoch [4/20]: 100%|██████████| 391/391 [00:23<00:00, 16.53it/s]


Epoch [4/20], Loss: 1.2812214900770456


Epoch [5/20]: 100%|██████████| 391/391 [00:23<00:00, 16.51it/s]


Epoch [5/20], Loss: 1.241307169884977


Epoch [6/20]: 100%|██████████| 391/391 [00:23<00:00, 16.53it/s]


Epoch [6/20], Loss: 1.2162025648614634


Epoch [7/20]: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s]


Epoch [7/20], Loss: 1.1942088259455492


Epoch [8/20]: 100%|██████████| 391/391 [00:23<00:00, 16.49it/s]


Epoch [8/20], Loss: 1.1785253181177027


Epoch [9/20]: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s]


Epoch [9/20], Loss: 1.169309815939735


Epoch [10/20]: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s]


Epoch [10/20], Loss: 1.1617975613040388


Epoch [11/20]: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s]


Epoch [11/20], Loss: 1.156444044064378


Epoch [12/20]: 100%|██████████| 391/391 [00:23<00:00, 16.53it/s]


Epoch [12/20], Loss: 1.1498771346438572


Epoch [13/20]: 100%|██████████| 391/391 [00:23<00:00, 16.49it/s]


Epoch [13/20], Loss: 1.1375941253074295


Epoch [14/20]: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s]


Epoch [14/20], Loss: 1.1314440418387313


Epoch [15/20]: 100%|██████████| 391/391 [00:23<00:00, 16.54it/s]


Epoch [15/20], Loss: 1.1368016670731937


Epoch [16/20]: 100%|██████████| 391/391 [00:23<00:00, 16.48it/s]


Epoch [16/20], Loss: 1.1236925704399947


Epoch [17/20]: 100%|██████████| 391/391 [00:23<00:00, 16.52it/s]


Epoch [17/20], Loss: 1.1276405901860094


Epoch [18/20]: 100%|██████████| 391/391 [00:23<00:00, 16.51it/s]


Epoch [18/20], Loss: 1.117490833067833


Epoch [19/20]: 100%|██████████| 391/391 [00:23<00:00, 16.53it/s]


Epoch [19/20], Loss: 1.1184365011542046


Epoch [20/20]: 100%|██████████| 391/391 [00:23<00:00, 16.54it/s]


Epoch [20/20], Loss: 1.1145461104105197
Accuracy of the model on the test images: 61.19 %


In [13]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
#from torchlars import LARS
import multiprocessing
from tqdm import tqdm

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SimCLRDataTransform:
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)

class ResNetSimCLR(nn.Module):
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.backbone = models.resnet18(weights=None)


        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()

        dim_mlp = self.backbone.fc.in_features


        self.backbone.fc = nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, out_dim)
        )

    def forward(self, x):
        return self.backbone(x)

class Classifier(nn.Module):
    def __init__(self, encoder, num_classes = 10):
        super(Classifier, self).__init__()
        self.encoder = nn.Sequential(*list(encoder.backbone.children())[:-1])
        dim_mlp = encoder.backbone.fc[0].in_features

        self.fc = nn.Linear(dim_mlp, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x).squeeze()
        logits = self.fc(features)
        return logits


def main():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    print("Start")


    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=SimCLRDataTransform())
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)


    model = ResNetSimCLR(base_model='resnet18', out_dim=128).to(device)

    def nt_xent_loss(out_1, out_2, temperature):
        batch_size = out_1.shape[0]
        out = torch.cat([out_1, out_2], dim=0)
        sim_matrix = F.cosine_similarity(out.unsqueeze(1), out.unsqueeze(0), dim=2)

        sim_ij = torch.diag(sim_matrix, batch_size)
        sim_ji = torch.diag(sim_matrix, -batch_size)

        positives = torch.cat([sim_ij, sim_ji], dim=0)
        negatives = sim_matrix[~torch.eye(2 * batch_size, dtype=bool)].view(2 * batch_size, -1)

        logits = torch.cat((positives.unsqueeze(1), negatives), dim=1)
        labels = torch.zeros(2 * batch_size).long().to(device)

        logits = logits / temperature
        loss = F.cross_entropy(logits, labels)
        return loss

    #def train_simclr():
    temperature = 0.5
    epochs = 200
    learning_rate = 3e-4

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=1e-4, momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    #optimizer = LARS(model.parameters(), lr=4.8, weight_decay = 1e-6)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for step, ((x_i, x_j), _) in enumerate(tqdm(train_loader,desc=f"Epoch [{epoch+1}/{epochs}]")):
            optimizer.zero_grad()
            x_i, x_j = x_i.to(device), x_j.to(device)

            x = torch.cat([x_i, x_j],dim=0)
            h = model(x)
            h_i, h_j = torch.chunk(h, 2)


            loss = nt_xent_loss(h_i, h_j, temperature)


            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader)}')

    torch.save(model, 'simclr_res18_100.pth')


    for param in model.backbone.parameters():
        param.requires_grad = False


    #def train_classifier():
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    classifier = Classifier(model).to(device)
    train_class = datasets.CIFAR10(root = './data', train=True, download=True, transform = transform)
    train_class_load = torch.utils.data.DataLoader(train_class, batch_size = 128, shuffle = True, num_workers=4)

    test_class = datasets.CIFAR10(root = './data', train=True, download=True, transform = transform)
    test_class_load = torch.utils.data.DataLoader(test_class, batch_size = 128, shuffle = True, num_workers=4)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.fc.parameters(), lr=3e-4)

    epochs = 20

    for epoch in range(epochs):
        classifier.train()
        running_loss = 0.0

        for step, (images, labels) in enumerate(tqdm(train_class_load,desc=f"Epoch [{epoch+1}/{epochs}]")):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()


            outputs = classifier(images)
            loss = criterion(outputs, labels)


            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader)}')


    classifier.eval()
    correct = 0
    total = 0

    torch.save(classifier, 'classifier_res18_100.pth')

    with torch.no_grad():
        for images, labels in test_class_load:
            images, labels = images.to(device), labels.to(device)

            outputs = classifier(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the test images: {accuracy} %')

if __name__ == '__main__':
    multiprocessing.freeze_support()
    main()

Using device: cuda
Start
Files already downloaded and verified


Epoch [1/200]: 100%|██████████| 391/391 [00:45<00:00,  8.60it/s]


Epoch [1/200], Loss: 4.702502332380056


Epoch [2/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [2/200], Loss: 4.4882843262704135


Epoch [3/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [3/200], Loss: 4.398215267664331


Epoch [4/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [4/200], Loss: 4.345229720825429


Epoch [5/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [5/200], Loss: 4.306858363358871


Epoch [6/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [6/200], Loss: 4.282972180325052


Epoch [7/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [7/200], Loss: 4.260695441604575


Epoch [8/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [8/200], Loss: 4.246405816139163


Epoch [9/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [9/200], Loss: 4.234199727587687


Epoch [10/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [10/200], Loss: 4.222021922431029


Epoch [11/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [11/200], Loss: 4.208750513203613


Epoch [12/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [12/200], Loss: 4.199765682220459


Epoch [13/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [13/200], Loss: 4.192166336357136


Epoch [14/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [14/200], Loss: 4.185472433524364


Epoch [15/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [15/200], Loss: 4.1770738317533525


Epoch [16/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [16/200], Loss: 4.1731793063375955


Epoch [17/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [17/200], Loss: 4.168070155336424


Epoch [18/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [18/200], Loss: 4.158481711926668


Epoch [19/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [19/200], Loss: 4.15626712223453


Epoch [20/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [20/200], Loss: 4.151009558411816


Epoch [21/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [21/200], Loss: 4.14721897495982


Epoch [22/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [22/200], Loss: 4.145617289921207


Epoch [23/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [23/200], Loss: 4.139624804182126


Epoch [24/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [24/200], Loss: 4.1361313567442055


Epoch [25/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [25/200], Loss: 4.131014625739564


Epoch [26/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [26/200], Loss: 4.130609705015217


Epoch [27/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [27/200], Loss: 4.1277157626188625


Epoch [28/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [28/200], Loss: 4.118300642198919


Epoch [29/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [29/200], Loss: 4.117807579162481


Epoch [30/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [30/200], Loss: 4.110931813564447


Epoch [31/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [31/200], Loss: 4.113654594592122


Epoch [32/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [32/200], Loss: 4.112037339783691


Epoch [33/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [33/200], Loss: 4.108255898556136


Epoch [34/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [34/200], Loss: 4.104406800721308


Epoch [35/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [35/200], Loss: 4.103442973188122


Epoch [36/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [36/200], Loss: 4.100553122322882


Epoch [37/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [37/200], Loss: 4.094579598482917


Epoch [38/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [38/200], Loss: 4.099609059750882


Epoch [39/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [39/200], Loss: 4.09496405850286


Epoch [40/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [40/200], Loss: 4.092155724230325


Epoch [41/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [41/200], Loss: 4.094032277231631


Epoch [42/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [42/200], Loss: 4.089138091982478


Epoch [43/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [43/200], Loss: 4.08274899480288


Epoch [44/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [44/200], Loss: 4.084215617850614


Epoch [45/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [45/200], Loss: 4.080189377450577


Epoch [46/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [46/200], Loss: 4.0817794525409905


Epoch [47/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [47/200], Loss: 4.079264365803555


Epoch [48/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [48/200], Loss: 4.075395409713316


Epoch [49/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [49/200], Loss: 4.07158585704501


Epoch [50/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [50/200], Loss: 4.074555623256947


Epoch [51/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [51/200], Loss: 4.071755639122575


Epoch [52/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [52/200], Loss: 4.073699929220292


Epoch [53/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [53/200], Loss: 4.068776759954974


Epoch [54/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [54/200], Loss: 4.067790916203843


Epoch [55/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [55/200], Loss: 4.067904701013394


Epoch [56/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [56/200], Loss: 4.066613846727649


Epoch [57/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [57/200], Loss: 4.063325653905454


Epoch [58/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [58/200], Loss: 4.0652364507660534


Epoch [59/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [59/200], Loss: 4.062398814789169


Epoch [60/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [60/200], Loss: 4.062865239579964


Epoch [61/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [61/200], Loss: 4.062312265186359


Epoch [62/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [62/200], Loss: 4.057071672985925


Epoch [63/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [63/200], Loss: 4.057684676421573


Epoch [64/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [64/200], Loss: 4.054126741331252


Epoch [65/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [65/200], Loss: 4.055229504090136


Epoch [66/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [66/200], Loss: 4.054447046021366


Epoch [67/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [67/200], Loss: 4.051902232572551


Epoch [68/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [68/200], Loss: 4.051151794545791


Epoch [69/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [69/200], Loss: 4.0504074761324835


Epoch [70/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [70/200], Loss: 4.052217026805634


Epoch [71/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [71/200], Loss: 4.048553300330706


Epoch [72/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [72/200], Loss: 4.047979364614657


Epoch [73/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [73/200], Loss: 4.045867397962019


Epoch [74/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [74/200], Loss: 4.04491887190153


Epoch [75/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [75/200], Loss: 4.045303631926436


Epoch [76/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [76/200], Loss: 4.043946546056996


Epoch [77/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [77/200], Loss: 4.043453235455486


Epoch [78/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [78/200], Loss: 4.039558122834891


Epoch [79/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [79/200], Loss: 4.040736326476193


Epoch [80/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [80/200], Loss: 4.039812978881095


Epoch [81/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [81/200], Loss: 4.041377648673095


Epoch [82/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [82/200], Loss: 4.036837224765202


Epoch [83/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [83/200], Loss: 4.0359122868998885


Epoch [84/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [84/200], Loss: 4.033202768896547


Epoch [85/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [85/200], Loss: 4.033064251970452


Epoch [86/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [86/200], Loss: 4.0325164264425295


Epoch [87/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [87/200], Loss: 4.034802700857372


Epoch [88/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [88/200], Loss: 4.031123188755396


Epoch [89/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [89/200], Loss: 4.0318695035432


Epoch [90/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [90/200], Loss: 4.028685275246115


Epoch [91/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [91/200], Loss: 4.02969263520692


Epoch [92/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [92/200], Loss: 4.033384599953966


Epoch [93/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [93/200], Loss: 4.02809339655025


Epoch [94/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [94/200], Loss: 4.026512844788144


Epoch [95/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [95/200], Loss: 4.027868646489995


Epoch [96/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [96/200], Loss: 4.028962810935877


Epoch [97/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [97/200], Loss: 4.0258106478035


Epoch [98/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [98/200], Loss: 4.0291564592619995


Epoch [99/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [99/200], Loss: 4.025748456530558


Epoch [100/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [100/200], Loss: 4.021074163944215


Epoch [101/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [101/200], Loss: 4.022221973180161


Epoch [102/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [102/200], Loss: 4.023427974232627


Epoch [103/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [103/200], Loss: 4.021509807128126


Epoch [104/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [104/200], Loss: 4.021642409931973


Epoch [105/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [105/200], Loss: 4.022282072345314


Epoch [106/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [106/200], Loss: 4.017345710178775


Epoch [107/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [107/200], Loss: 4.019408863828615


Epoch [108/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [108/200], Loss: 4.019784857854819


Epoch [109/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [109/200], Loss: 4.016855863658973


Epoch [110/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [110/200], Loss: 4.017184988616982


Epoch [111/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [111/200], Loss: 4.015283471178216


Epoch [112/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [112/200], Loss: 4.015565124009272


Epoch [113/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [113/200], Loss: 4.0169832828404655


Epoch [114/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [114/200], Loss: 4.0143078178396


Epoch [115/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [115/200], Loss: 4.014202180725839


Epoch [116/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [116/200], Loss: 4.011008859595375


Epoch [117/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [117/200], Loss: 4.012736554036055


Epoch [118/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [118/200], Loss: 4.0132644524049885


Epoch [119/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [119/200], Loss: 4.012746584080064


Epoch [120/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [120/200], Loss: 4.011252797778


Epoch [121/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [121/200], Loss: 4.007940471019891


Epoch [122/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [122/200], Loss: 4.010734464811242


Epoch [123/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [123/200], Loss: 4.011137523309654


Epoch [124/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [124/200], Loss: 4.006466802123867


Epoch [125/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [125/200], Loss: 4.009795018779043


Epoch [126/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [126/200], Loss: 4.009445024573284


Epoch [127/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [127/200], Loss: 4.006243846605501


Epoch [128/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [128/200], Loss: 4.007794787511801


Epoch [129/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [129/200], Loss: 4.0051113181102


Epoch [130/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [130/200], Loss: 4.006779429247922


Epoch [131/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [131/200], Loss: 4.0017611425551


Epoch [132/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [132/200], Loss: 4.005157946930517


Epoch [133/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [133/200], Loss: 4.00492889557958


Epoch [134/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [134/200], Loss: 4.004200492673518


Epoch [135/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [135/200], Loss: 4.0035094788007415


Epoch [136/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [136/200], Loss: 4.002609388297781


Epoch [137/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [137/200], Loss: 4.0004745102904335


Epoch [138/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [138/200], Loss: 4.002249721980766


Epoch [139/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [139/200], Loss: 3.999751632475792


Epoch [140/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [140/200], Loss: 4.0031121478361245


Epoch [141/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [141/200], Loss: 4.001478303119045


Epoch [142/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [142/200], Loss: 3.999293388917928


Epoch [143/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [143/200], Loss: 4.000114475064875


Epoch [144/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [144/200], Loss: 3.997424648241009


Epoch [145/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [145/200], Loss: 3.997928040106888


Epoch [146/200]: 100%|██████████| 391/391 [00:45<00:00,  8.61it/s]


Epoch [146/200], Loss: 3.997933731664477


Epoch [147/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [147/200], Loss: 3.998943185562368


Epoch [148/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [148/200], Loss: 3.99876114962351


Epoch [149/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [149/200], Loss: 3.997028399611373


Epoch [150/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [150/200], Loss: 3.9964673964263837


Epoch [151/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [151/200], Loss: 3.994473740877703


Epoch [152/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [152/200], Loss: 3.995212410417054


Epoch [153/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [153/200], Loss: 3.992420322145038


Epoch [154/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [154/200], Loss: 3.9916321938604953


Epoch [155/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [155/200], Loss: 3.9970275008160137


Epoch [156/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [156/200], Loss: 3.9937622187387607


Epoch [157/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [157/200], Loss: 3.9942739040345487


Epoch [158/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [158/200], Loss: 3.991959648059152


Epoch [159/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [159/200], Loss: 3.9914533598038853


Epoch [160/200]: 100%|██████████| 391/391 [00:45<00:00,  8.61it/s]


Epoch [160/200], Loss: 3.990925709000024


Epoch [161/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [161/200], Loss: 3.9943275915082457


Epoch [162/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [162/200], Loss: 3.9901792539659975


Epoch [163/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [163/200], Loss: 3.992980069211682


Epoch [164/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [164/200], Loss: 3.9919517168303584


Epoch [165/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [165/200], Loss: 3.9880278811735264


Epoch [166/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [166/200], Loss: 3.9910149373056942


Epoch [167/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [167/200], Loss: 3.9904745911698205


Epoch [168/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [168/200], Loss: 3.9907388516399256


Epoch [169/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [169/200], Loss: 3.986069206691459


Epoch [170/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [170/200], Loss: 3.990197604269628


Epoch [171/200]: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s]


Epoch [171/200], Loss: 3.9892242326760843


Epoch [172/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [172/200], Loss: 3.9887898462202847


Epoch [173/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [173/200], Loss: 3.9868557514132137


Epoch [174/200]: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s]


Epoch [174/200], Loss: 3.983868271493546


Epoch [175/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [175/200], Loss: 3.9872548281384246


Epoch [176/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [176/200], Loss: 3.982616594685313


Epoch [177/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [177/200], Loss: 3.984009350047392


Epoch [178/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [178/200], Loss: 3.9851729083244147


Epoch [179/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [179/200], Loss: 3.984474500426856


Epoch [180/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [180/200], Loss: 3.9854117348370957


Epoch [181/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [181/200], Loss: 3.985502499753557


Epoch [182/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [182/200], Loss: 3.9845881230386015


Epoch [183/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [183/200], Loss: 3.9820500438475546


Epoch [184/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [184/200], Loss: 3.9837796285634153


Epoch [185/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [185/200], Loss: 3.981153489988478


Epoch [186/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [186/200], Loss: 3.9837502158816207


Epoch [187/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [187/200], Loss: 3.98084098176883


Epoch [188/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [188/200], Loss: 3.979236710102052


Epoch [189/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [189/200], Loss: 3.9786273573365665


Epoch [190/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [190/200], Loss: 3.9798303827300403


Epoch [191/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [191/200], Loss: 3.981928804036601


Epoch [192/200]: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s]


Epoch [192/200], Loss: 3.9810317562669133


Epoch [193/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [193/200], Loss: 3.9836454476846757


Epoch [194/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [194/200], Loss: 3.9795216959150856


Epoch [195/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [195/200], Loss: 3.980050611983785


Epoch [196/200]: 100%|██████████| 391/391 [00:45<00:00,  8.62it/s]


Epoch [196/200], Loss: 3.9783191754080147


Epoch [197/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [197/200], Loss: 3.97839369859232


Epoch [198/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [198/200], Loss: 3.9784508507574916


Epoch [199/200]: 100%|██████████| 391/391 [00:45<00:00,  8.65it/s]


Epoch [199/200], Loss: 3.9818834727987302


Epoch [200/200]: 100%|██████████| 391/391 [00:45<00:00,  8.66it/s]


Epoch [200/200], Loss: 3.976852599312277
Files already downloaded and verified
Files already downloaded and verified


Epoch [1/20]: 100%|██████████| 391/391 [00:06<00:00, 57.20it/s]


Epoch [1/20], Loss: 0.9558241553318775


Epoch [2/20]: 100%|██████████| 391/391 [00:06<00:00, 56.19it/s]


Epoch [2/20], Loss: 0.6728371501426258


Epoch [3/20]: 100%|██████████| 391/391 [00:06<00:00, 56.57it/s]


Epoch [3/20], Loss: 0.6346103368360368


Epoch [4/20]: 100%|██████████| 391/391 [00:06<00:00, 56.63it/s]


Epoch [4/20], Loss: 0.6201018104162972


Epoch [5/20]: 100%|██████████| 391/391 [00:07<00:00, 55.75it/s]


Epoch [5/20], Loss: 0.6085138655532046


Epoch [6/20]: 100%|██████████| 391/391 [00:06<00:00, 57.36it/s]


Epoch [6/20], Loss: 0.5984554062109164


Epoch [7/20]: 100%|██████████| 391/391 [00:06<00:00, 56.58it/s]


Epoch [7/20], Loss: 0.5962074669578191


Epoch [8/20]: 100%|██████████| 391/391 [00:06<00:00, 56.36it/s]


Epoch [8/20], Loss: 0.5825565961163367


Epoch [9/20]: 100%|██████████| 391/391 [00:06<00:00, 55.95it/s]


Epoch [9/20], Loss: 0.5846339369673863


Epoch [10/20]: 100%|██████████| 391/391 [00:06<00:00, 56.37it/s]


Epoch [10/20], Loss: 0.5800871934427325


Epoch [11/20]: 100%|██████████| 391/391 [00:06<00:00, 56.52it/s]


Epoch [11/20], Loss: 0.5790243156425788


Epoch [12/20]: 100%|██████████| 391/391 [00:07<00:00, 54.24it/s]


Epoch [12/20], Loss: 0.5754834560634535


Epoch [13/20]: 100%|██████████| 391/391 [00:06<00:00, 57.67it/s]


Epoch [13/20], Loss: 0.5722946744135884


Epoch [14/20]: 100%|██████████| 391/391 [00:06<00:00, 56.82it/s]


Epoch [14/20], Loss: 0.572203705987662


Epoch [15/20]: 100%|██████████| 391/391 [00:07<00:00, 55.30it/s]


Epoch [15/20], Loss: 0.5698682039290133


Epoch [16/20]: 100%|██████████| 391/391 [00:06<00:00, 57.44it/s]


Epoch [16/20], Loss: 0.5668951237902922


Epoch [17/20]: 100%|██████████| 391/391 [00:07<00:00, 55.45it/s]


Epoch [17/20], Loss: 0.5688486778370255


Epoch [18/20]: 100%|██████████| 391/391 [00:07<00:00, 55.37it/s]


Epoch [18/20], Loss: 0.5655832057413848


Epoch [19/20]: 100%|██████████| 391/391 [00:06<00:00, 56.41it/s]


Epoch [19/20], Loss: 0.5603087848562109


Epoch [20/20]: 100%|██████████| 391/391 [00:07<00:00, 55.26it/s]

Epoch [20/20], Loss: 0.5654777619235046





Accuracy of the model on the test images: 81.128 %
