In [1]:
import os
import math
import numpy as np
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
from torchvision import transforms

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import style

plt.style.use('fivethirtyeight')
sns.set(style='whitegrid',color_codes=True)

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

EPOCHS = 10
BATCH_SIZE = 30
LEARNING_RATE = 0.003
IMG_SIZE = 64
CONV_SIZE = math.floor((((IMG_SIZE-2)/2)-2)/2)

TRAIN_DATA_PATH = "C:/Users/Administrator/Desktop/datasets/images/flowers_train_test/train"
TEST_DATA_PATH = "C:/Users/Administrator/Desktop/datasets/images/flowers_train_test/test"

TRANSFORM_IMG = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])

train_data = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
train_data_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)
test_data = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH, transform=TRANSFORM_IMG)
test_data_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
train_data_loader

<torch.utils.data.dataloader.DataLoader at 0x4e2a580>

In [4]:
print(len(train_data))
print(len(test_data))
train_data.classes

3813
510


['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

In [5]:
def calc_accuracy(true, pred):
    pred = F.softmax(pred, dim=1)
    print(true.unsqueeze(1))
    true = torch.zeros(pred.shape[0], pred.shape[1]).scatter_(1, true.unsqueeze(1), 1.)
    acc = (true.argmax(-1) == pred.argmax(-1)).float().detach().numpy()
    acc = float((100 * acc.sum()) / len(acc))
    return round(acc, 4)

def plot_loss(train_loss, val_loss):
    plt.plot(train_loss, label='train loss')
    plt.plot(val_loss, label='test loss')
    plt.xlabel('Epoch')
    plt.legend()
    plt.savefig('results/plot_loss.png')
    plt.close()
    print("Loss plot saved.")

def plot_accu(train_accuracy, val_accuracy):
    plt.plot(train_accuracy, label='train accuracy')
    plt.plot(val_accuracy, label='test accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Percent')
    plt.legend()
    plt.savefig('results/plot_accu.png')
    plt.close()
    print("Accu plot saved.")

In [6]:
class Flower_Net_1(nn.Module):
    def __init__(self):
        super(Flower_Net_1,self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3,8,kernel_size=3,padding=1),nn.AvgPool2d(kernel_size=3,padding=1))
        self.layer2 = nn.Sequential(nn.Conv2d(8,16,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3,padding=1))
        self.layer3 = nn.Sequential(nn.Conv2d(16,8,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3,padding=1))
        self.flatten = nn.Flatten()

    def forward(self,x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        out = self.flatten(x)
        return out

class Flower_Net_2(nn.Module):
    def __init__(self):
        super(Flower_Net_2,self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3,16,kernel_size=3,padding=1),nn.AvgPool2d(kernel_size=3,padding=1))
        self.layer2 = nn.Sequential(nn.Conv2d(16,32,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3,padding=1))
        self.layer3 = nn.Sequential(nn.Conv2d(32,8,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3,padding=1))
        self.flatten = nn.Flatten()

    def forward(self,x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        out = self.flatten(x)
        return out
    
class Flower_Net_3(nn.Module):
    def __init__(self):
        super(Flower_Net_3,self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3,32,kernel_size=3,padding=1),nn.AvgPool2d(kernel_size=3,padding=1))
        self.layer2 = nn.Sequential(nn.Conv2d(32,8,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3,padding=1))
        self.layer3 = nn.Sequential(nn.Conv2d(8,8,kernel_size=3,padding=1),nn.MaxPool2d(kernel_size=3,padding=1))
        self.flatten = nn.Flatten()

    def forward(self,x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        out = self.flatten(x)
        return out

class ensemble_Net(nn.Module):    
    def __init__(self):
        super(ensemble_Net,self).__init__()
        f1 = Flower_Net_1()
        f2 = Flower_Net_2()
        f3 = Flower_Net_3()
        self.e1 = f1
        self.e2 = f2
        self.e3 = f3
        self.avgpool = nn.AvgPool1d(kernel_size=1)
        self.fc1 = nn.Linear(216,30)
        self.fc2 = nn.Linear(30,5)
    
    def forward(self,x):
        o1 = self.e1(x)
    
        o2 = self.e2(x)
        o3 = self.e3(x)
        x = torch.cat((o1,o2,o3),dim=1)
        #print(x.size())
        x = self.fc1(x)
        out = self.fc2(x)
        
        return out
    
        
model = ensemble_Net()

In [7]:
model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=LEARNING_RATE, weight_decay=LEARNING_RATE)

In [8]:
import torch.nn.functional as F 

# train_data_loader, test_data_loader = data_loader(TRAIN_DATA_PATH, TEST_DATA_PATH)
model = ensemble_Net()
model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=LEARNING_RATE)

train_loss = []
train_accuracy = []
val_loss = []
val_accuracy = []

for epoch in range(EPOCHS):

    start = time.time()

    train_epoch_loss = []
    train_epoch_accuracy = []
    val_epoch_loss = []
    val_epoch_accuracy = []

    for images, labels in train_data_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        preds = model(images)

        acc = calc_accuracy(labels.cpu(), preds.cpu())
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_epoch_loss.append(loss_value)
        train_epoch_accuracy.append(acc)

    for images, labels in test_data_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        preds = model(images)
        acc = calc_accuracy(labels.cpu(), preds.cpu())
        loss = criterion(preds, labels)

        loss_value = loss.item()
        val_epoch_loss.append(loss_value)
        val_epoch_accuracy.append(acc)

    train_epoch_loss = np.mean(train_epoch_loss)
    train_epoch_accuracy = np.mean(train_epoch_accuracy)
    val_epoch_loss = np.mean(val_epoch_loss)
    val_epoch_accuracy = np.mean(val_epoch_accuracy)

    end = time.time()

    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)

    print("@@ Epoch {} = {}s".format(epoch, int(end - start)))
    print("Train Loss = {}".format(round(train_epoch_loss, 3)))
    print("Train Accu = {} %".format(train_epoch_accuracy))
    print("Valid Loss = {}".format(round(val_epoch_loss, 3)))
    print("Valid Accu = {} % \n".format(val_epoch_accuracy))

plot_loss(train_loss, val_loss)
plot_accu(train_accuracy, val_accuracy)

labels
tensor([4, 3, 1, 4, 1, 4, 1, 3, 4, 2, 3, 2, 0, 4, 4, 3, 4, 1, 4, 3, 4, 4, 4, 2,
        4, 3, 2, 0, 4, 0])
torch.Size([30, 3, 64, 64])
tensor([[ 0.0304,  0.2241,  0.0523,  0.1433, -0.0783],
        [ 0.0335,  0.2483,  0.0414,  0.1478, -0.0705],
        [ 0.0145,  0.2204,  0.0370,  0.1669, -0.0753],
        [ 0.0619,  0.2565,  0.0753,  0.1835, -0.0975],
        [ 0.0606,  0.2617,  0.0330,  0.1485, -0.0498],
        [ 0.0417,  0.2508,  0.0493,  0.1644, -0.0493],
        [ 0.0422,  0.1833,  0.0098,  0.1439, -0.0710],
        [ 0.0359,  0.2384,  0.0305,  0.1429, -0.0489],
        [ 0.0281,  0.2165,  0.0494,  0.1424, -0.0586],
        [ 0.0298,  0.2144,  0.0304,  0.1457, -0.0678],
        [ 0.0199,  0.2312,  0.0111,  0.1442, -0.0588],
        [ 0.0364,  0.2199,  0.0282,  0.1685, -0.0910],
        [ 0.0442,  0.2430,  0.0332,  0.1851, -0.0683],
        [ 0.0553,  0.2130,  0.0665,  0.1911, -0.0821],
        [ 0.0472,  0.1845,  0.0565,  0.1580, -0.0732],
        [ 0.0333,  0.2434,  0.062

tensor([[0],
        [0],
        [2],
        [2],
        [4],
        [3],
        [2],
        [2],
        [2],
        [1],
        [0],
        [4],
        [0],
        [3],
        [2],
        [2],
        [3],
        [2],
        [1],
        [4],
        [4],
        [4],
        [4],
        [1],
        [4],
        [3],
        [4],
        [4],
        [4],
        [4]])
labels
tensor([1, 4, 0, 1, 2, 4, 1, 2, 4, 4, 3, 3, 0, 1, 3, 0, 4, 3, 0, 4, 4, 2, 1, 1,
        2, 4, 4, 2, 1, 1])
torch.Size([30, 3, 64, 64])
tensor([[-0.0054,  0.1538, -0.0083,  0.0841,  0.0235],
        [-0.0390,  0.1401,  0.0085,  0.1219,  0.0427],
        [-0.0133,  0.1523, -0.0011,  0.1076,  0.0308],
        [-0.0311,  0.1505, -0.0113,  0.0911,  0.0272],
        [-0.0323,  0.1476,  0.0155,  0.1160,  0.0423],
        [-0.0200,  0.1414, -0.0034,  0.1087,  0.0285],
        [-0.0432,  0.1519, -0.0041,  0.1172,  0.0328],
        [-0.0037,  0.1154,  0.0051,  0.1169,  0.0491],
        [-0.0352,  0.1389, 

tensor([[2],
        [1],
        [2],
        [1],
        [1],
        [3],
        [1],
        [2],
        [1],
        [4],
        [4],
        [0],
        [2],
        [4],
        [3],
        [1],
        [1],
        [1],
        [4],
        [0],
        [1],
        [2],
        [4],
        [2],
        [2],
        [3],
        [3],
        [4],
        [1],
        [4]])
labels
tensor([1, 3, 1, 2, 1, 2, 3, 1, 4, 4, 0, 2, 4, 1, 4, 3, 4, 1, 1, 0, 2, 1, 3, 4,
        3, 0, 2, 4, 1, 1])
torch.Size([30, 3, 64, 64])
tensor([[-0.0905,  0.1878, -0.0325, -0.0346,  0.1470],
        [-0.2351,  0.1894, -0.0237,  0.0402,  0.3261],
        [-0.1139,  0.2001, -0.0406, -0.0279,  0.2113],
        [-0.1803,  0.1936, -0.0220, -0.0165,  0.2337],
        [-0.2136,  0.2060, -0.0513,  0.0194,  0.2796],
        [-0.0736,  0.1659, -0.0276, -0.0049,  0.1143],
        [-0.1365,  0.2006, -0.0545, -0.0696,  0.2081],
        [-0.1536,  0.1813, -0.0203,  0.0369,  0.2448],
        [-0.1402,  0.1378, 

tensor([[3],
        [1],
        [1],
        [0],
        [4],
        [0],
        [0],
        [1],
        [4],
        [3],
        [2],
        [3],
        [4],
        [1],
        [1],
        [4],
        [3],
        [1],
        [0],
        [3],
        [4],
        [3],
        [1],
        [2],
        [3],
        [3],
        [3],
        [3],
        [2],
        [3]])
labels
tensor([2, 1, 1, 4, 3, 4, 1, 2, 2, 1, 3, 4, 2, 0, 4, 2, 1, 2, 1, 4, 2, 3, 1, 4,
        1, 4, 2, 2, 2, 1])
torch.Size([30, 3, 64, 64])
tensor([[-0.4526,  0.3936, -0.1070, -0.1009,  0.4367],
        [-0.3252,  0.2492, -0.0531,  0.0247,  0.2869],
        [-0.2853,  0.3554, -0.0926,  0.0338,  0.2562],
        [-0.5849,  0.2065,  0.0546,  0.2354,  0.5526],
        [-0.4823,  0.1808,  0.0645,  0.1221,  0.4165],
        [-0.4765,  0.2995, -0.0389, -0.0729,  0.4849],
        [-0.4238,  0.3685, -0.0632, -0.0954,  0.4044],
        [-0.2914,  0.1267,  0.0275,  0.0787,  0.3247],
        [-0.4677,  0.2463, 

tensor([[3],
        [1],
        [3],
        [4],
        [4],
        [4],
        [2],
        [1],
        [4],
        [1],
        [1],
        [0],
        [4],
        [1],
        [2],
        [2],
        [0],
        [4],
        [2],
        [3],
        [0],
        [2],
        [4],
        [3],
        [4],
        [2],
        [0],
        [1],
        [2],
        [1]])
labels
tensor([3, 0, 4, 3, 0, 2, 3, 4, 0, 1, 3, 0, 1, 1, 4, 0, 2, 2, 4, 4, 4, 3, 2, 2,
        0, 2, 2, 4, 4, 3])
torch.Size([30, 3, 64, 64])
tensor([[-0.5031,  0.0769,  0.2427,  0.5458,  0.3492],
        [-0.4260,  0.1850,  0.0544,  0.0604,  0.3611],
        [-0.5035, -0.0897,  0.3769,  0.1316,  0.5072],
        [-0.2429,  0.2413, -0.0322,  0.1272,  0.1703],
        [-0.4918, -0.1661,  0.4550,  0.3424,  0.4214],
        [-0.5796, -0.3272,  0.5544,  0.2577,  0.6235],
        [-0.3328,  0.2650,  0.0309,  0.2294,  0.2435],
        [-0.3514,  0.2498,  0.0029,  0.0016,  0.2604],
        [-0.3835,  0.2493, 

labels
tensor([1, 4, 2, 4, 1, 4, 0, 3, 0, 0, 1, 4, 1, 4, 4, 3, 0, 3, 1, 2, 4, 1, 1, 2,
        1, 3, 3, 4, 2, 0])
torch.Size([30, 3, 64, 64])
tensor([[-0.1372,  0.2588, -0.0210, -0.0862,  0.0739],
        [-0.9854, -1.0062,  1.1518, -0.1262,  1.0102],
        [-0.7215, -0.4601,  0.7114, -0.3630,  0.6227],
        [-0.5073, -0.0929,  0.3087,  0.1873,  0.4883],
        [-0.1695,  0.3980, -0.1306,  0.1205,  0.1194],
        [-0.2241,  0.2528,  0.0144, -0.1211,  0.1596],
        [-0.1950,  0.2949, -0.0407, -0.1715,  0.1425],
        [-0.5385, -0.3078,  0.4265,  0.0155,  0.4343],
        [-0.2681,  0.1932,  0.0914, -0.1877,  0.1666],
        [-0.2304,  0.1478,  0.1136, -0.0261,  0.2010],
        [-0.1944,  0.5435, -0.2525,  0.2703,  0.1977],
        [-0.4350, -0.1218,  0.3052, -0.0764,  0.3163],
        [-0.1406,  0.3961, -0.1190,  0.0075,  0.1000],
        [-0.2047,  0.1986,  0.0432, -0.2120,  0.1666],
        [-0.9792, -1.0116,  0.9344, -0.1836,  0.7589],
        [-0.5138,  0.2847,  0.232

tensor([0, 1, 1, 1, 2, 3, 2, 2, 4, 3, 1, 4, 3, 0, 2, 2, 2, 3, 0, 0, 0, 4, 4, 0,
        2, 2, 4, 4, 3, 1])
torch.Size([30, 3, 64, 64])
tensor([[ 0.0634,  0.5523, -0.3146, -0.3470,  0.1642],
        [ 0.0103,  0.7292, -0.4278,  0.4986,  0.0864],
        [-0.0969,  0.2534, -0.0536, -0.3124,  0.1594],
        [ 0.0476,  0.7190, -0.4236, -0.0141,  0.0683],
        [-0.5467, -0.1464,  0.0991, -0.2784,  1.0319],
        [-0.0118,  0.8279, -0.5830,  0.4281,  0.1209],
        [-1.5621, -1.5958,  1.5306, -1.1439,  1.9103],
        [-1.9112, -1.9625,  2.0464, -1.1208,  2.5406],
        [-0.0545,  0.3342, -0.1226, -0.2144,  0.0994],
        [-0.0449,  0.8598, -0.6641,  0.8983,  0.1974],
        [ 0.0331,  0.6493, -0.4279,  0.0465,  0.0511],
        [-0.5505, -0.1830,  0.2364, -0.5085,  1.1056],
        [-0.1038,  1.0703, -0.8607,  0.8668,  0.2720],
        [-1.4507, -1.5720,  1.6773, -1.4288,  1.9206],
        [-2.3103, -2.7746,  2.7914, -1.8051,  3.2385],
        [-2.0470, -2.2785,  1.8913, -1.6

tensor([2, 1, 2, 1, 4, 1, 4, 4, 4, 0, 1, 0, 3, 1, 1, 4, 2, 1, 1, 3, 0, 4, 0, 4,
        0, 2, 3, 1, 0, 3])
torch.Size([30, 3, 64, 64])
tensor([[-4.6210e-01, -8.5834e-01,  8.3541e-01, -7.4966e-01,  7.5047e-01],
        [ 3.4513e-01,  2.0241e-01, -1.2875e-01, -5.7349e-01, -1.9563e-02],
        [-1.4595e+00, -2.2988e+00,  2.3839e+00, -1.3539e+00,  1.8376e+00],
        [ 4.2547e-01,  6.3146e-01, -5.9343e-01,  4.5139e-01,  1.7230e-01],
        [ 1.2951e-01,  1.0926e+00, -9.4425e-01,  1.9431e+00,  3.4575e-01],
        [ 5.9083e-01,  1.8166e+00, -1.6841e+00,  2.3038e+00,  5.6017e-02],
        [ 2.6763e-01,  2.0563e-01, -4.9221e-02, -5.6744e-01,  5.1453e-02],
        [ 9.4749e-01,  1.6922e+00, -1.7665e+00,  1.7183e+00, -1.7955e-01],
        [-2.4209e-01, -7.0736e-01,  5.9856e-01, -9.1233e-01,  7.1802e-01],
        [ 2.8062e-01,  2.3808e-01, -1.3668e-01, -4.2098e-01, -7.6561e-03],
        [ 6.6189e-01,  1.1692e+00, -1.2864e+00,  9.7824e-01, -1.0334e-01],
        [ 3.9211e-01,  4.0622e-02, -9.05

labels
tensor([1, 1, 0, 4, 0, 4, 4, 2, 0, 4, 4, 1, 0, 3, 4, 4, 2, 4, 3, 2, 0, 2, 4, 4,
        1, 1, 2, 1, 2, 3])
torch.Size([30, 3, 64, 64])
tensor([[ 1.1114e+00,  4.2142e-01, -4.9534e-01, -1.2130e+00, -1.1551e-01],
        [ 6.9278e-01,  6.1030e-01, -1.1263e+00,  3.7193e-01,  2.9046e-01],
        [ 8.0640e-01,  1.7762e-01, -2.0377e-01, -1.2555e+00, -6.3947e-02],
        [ 6.7202e-01,  1.4252e+00, -1.6551e+00,  1.8124e+00,  2.1660e-01],
        [ 5.6759e-01, -1.2664e-01,  8.7475e-04, -9.5674e-01,  3.7134e-01],
        [-8.3242e-01, -2.6989e+00,  1.9377e+00, -2.6076e+00,  2.2906e+00],
        [ 6.7561e-01, -3.5401e-01,  1.6549e-01, -1.7148e+00,  3.0143e-01],
        [ 1.8107e-01, -9.9575e-01,  5.7188e-01, -1.6348e+00,  9.0098e-01],
        [ 2.3612e-01,  1.4924e-01, -5.8096e-01,  7.2040e-01,  6.7317e-01],
        [ 7.1722e-01,  9.4131e-01, -1.2018e+00,  1.1481e+00,  6.3259e-01],
        [ 8.1591e-01,  1.1186e+00, -1.5877e+00,  1.1711e+00,  3.7884e-01],
        [ 3.3024e-01,  1.3859e+00

tensor([[4],
        [1],
        [2],
        [3],
        [4],
        [2],
        [2],
        [2],
        [4],
        [4],
        [1],
        [2],
        [4],
        [4],
        [0],
        [4],
        [0],
        [4],
        [3],
        [1],
        [4],
        [4],
        [3],
        [4],
        [2],
        [4],
        [0],
        [2],
        [0],
        [2]])
labels
tensor([4, 3, 1, 4, 2, 2, 4, 1, 0, 4, 1, 1, 2, 0, 0, 3, 4, 0, 1, 2, 1, 4, 3, 1,
        2, 4, 0, 4, 4, 1])
torch.Size([30, 3, 64, 64])
tensor([[ 0.1882,  0.4538, -0.1883, -0.5015, -0.0074],
        [ 0.1043,  1.5569, -1.2681,  1.9205,  0.3436],
        [ 0.2960,  0.0764,  0.1356, -0.8625,  0.1236],
        [ 0.3490,  0.6373, -0.3788,  0.0095, -0.2030],
        [-0.2227, -0.9297,  1.0800, -1.5031,  0.9435],
        [ 0.3655,  0.4785, -0.2364, -0.3329,  0.0081],
        [ 0.2546,  0.1711,  0.0079, -0.6161,  0.0504],
        [ 0.9205,  2.3173, -1.9645,  1.4791, -0.0604],
        [ 0.2021,  0.0780, 

tensor([[4],
        [1],
        [4],
        [2],
        [3],
        [1],
        [3],
        [0],
        [1],
        [4],
        [3],
        [2],
        [1],
        [3],
        [0],
        [3],
        [0],
        [1],
        [2],
        [1],
        [1],
        [2],
        [3],
        [0],
        [1],
        [0],
        [4],
        [0],
        [2],
        [0]])
labels
tensor([3, 0, 0, 3, 0, 0, 4, 3, 0, 2, 0, 4, 1, 3, 1, 2, 1, 1, 2, 2, 3, 4, 1, 0,
        2, 0, 1, 3, 3, 4])
torch.Size([30, 3, 64, 64])
tensor([[ 2.6997e-01,  1.4636e+00, -9.6693e-01,  9.6073e-01, -1.7788e-01],
        [ 4.4364e-01,  6.9072e-01, -4.1991e-01, -7.5508e-03, -1.0787e-03],
        [-5.1035e-01, -8.3278e-01,  1.0121e+00, -7.6451e-01,  6.1563e-01],
        [ 3.4012e-01,  8.8375e-01, -5.3808e-01,  2.6573e-01, -6.7870e-02],
        [ 3.7442e-01,  4.0664e-01, -7.2894e-02, -4.6831e-01, -1.3583e-01],
        [ 6.5306e-01,  5.7349e-01, -8.4128e-02, -9.1900e-01, -3.6035e-01],
        [-6.6863e

tensor([[-3.5237e-01, -4.0274e-01,  6.3716e-01,  1.9073e-01,  4.5411e-01],
        [-2.3862e-02,  1.2268e+00, -9.9046e-01,  1.6195e+00, -7.7615e-03],
        [ 2.8381e-01,  9.1117e-01, -6.2655e-01,  4.7706e-01, -6.3301e-02],
        [ 1.9077e-02,  6.5694e-01, -2.3876e-01,  6.3888e-01,  5.6403e-02],
        [-2.5118e-01, -9.6961e-01,  1.1789e+00, -8.6859e-01,  7.2891e-01],
        [ 1.3461e-01, -4.7475e-02,  3.0547e-01, -5.8564e-01,  1.0261e-01],
        [ 7.0394e-01,  9.4618e-01, -3.8164e-01, -5.1890e-01, -4.7825e-01],
        [-4.6088e-01,  1.1516e+00, -7.0365e-01,  2.0583e+00,  4.7067e-01],
        [-1.2339e+00, -8.1030e-01,  1.1153e+00,  1.3668e+00,  1.3147e+00],
        [ 3.3895e-01,  1.4917e+00, -7.3929e-01,  7.9860e-01, -2.7592e-01],
        [ 8.5747e-02, -9.7971e-02,  3.0097e-01, -6.2030e-01,  1.1624e-01],
        [ 7.0405e-01,  8.4760e-01, -2.4790e-01, -8.2478e-01, -4.6886e-01],
        [ 2.0657e-01,  8.1352e-01, -4.1383e-01,  3.1362e-01, -1.1842e-01],
        [-2.1336e-01, -6.

tensor([[3],
        [1],
        [3],
        [4],
        [1],
        [3],
        [1],
        [2],
        [3],
        [0],
        [4],
        [4],
        [2],
        [3],
        [3],
        [4],
        [2],
        [4],
        [0],
        [4],
        [2],
        [0],
        [1],
        [4],
        [4],
        [2],
        [4],
        [4],
        [3],
        [0]])
labels
tensor([1, 2, 2, 1, 0, 1, 0, 0, 4, 4, 4, 4, 1, 1, 0, 0, 4, 0, 1, 1, 4, 3, 2, 2,
        3, 4, 1, 3, 0, 1])
torch.Size([30, 3, 64, 64])
tensor([[-6.6854e-02,  7.0872e-01, -3.5233e-01,  5.9644e-01,  2.4476e-01],
        [-2.6357e-02, -4.7231e-01,  7.6404e-01, -1.4267e+00,  7.8812e-01],
        [ 2.8297e-01,  3.5921e-01, -2.2361e-01, -6.1524e-02,  2.6794e-02],
        [ 4.0430e-01, -1.4241e-02,  3.2887e-01, -1.1069e+00, -5.5416e-03],
        [ 4.6663e-01,  4.7828e-01,  5.9542e-03, -6.9458e-01, -1.4371e-01],
        [-4.8592e-01,  4.6766e-01, -2.3345e-01,  1.2460e+00,  3.6832e-01],
        [ 1.6787e

tensor([[1],
        [4],
        [1],
        [0],
        [0],
        [1],
        [3],
        [0],
        [2],
        [2],
        [4],
        [2],
        [3],
        [1],
        [3],
        [1],
        [1],
        [4],
        [3],
        [0],
        [1],
        [1],
        [4],
        [4],
        [1],
        [2],
        [0],
        [2],
        [4],
        [3]])
labels
tensor([3, 2, 3, 3, 3, 4, 1, 2, 1, 3, 4, 2, 3, 4, 0, 1, 1, 4, 4, 4, 1, 0, 0, 1,
        0, 1, 3, 3, 0, 4])
torch.Size([30, 3, 64, 64])
tensor([[ 0.1741,  1.0450, -0.9321,  0.8406,  0.3280],
        [-3.1172, -6.0873,  5.1357, -0.8688,  4.9271],
        [ 0.1718,  0.5583, -0.2341, -0.1300,  0.4836],
        [-0.9975,  1.0754, -1.1629,  2.9417,  1.0685],
        [-2.2137, -0.5525, -0.0524,  3.7590,  2.7422],
        [-0.1289, -0.6966,  0.7501, -0.9667,  0.9586],
        [ 0.5399,  1.2572, -1.1629,  0.3384,  0.3607],
        [-2.3050, -1.9097,  1.6422,  1.9377,  3.2238],
        [-1.0135,  1.6995, 

tensor([[-3.7903e+00, -5.6663e+00,  4.3918e+00,  7.2614e-01,  6.3456e+00],
        [-8.7860e-02,  1.5571e+00, -1.3141e+00,  1.1401e+00,  2.3055e-01],
        [-7.1574e-01, -1.3838e+00,  7.6483e-01, -3.4827e-01,  1.9336e+00],
        [-3.4865e+00, -7.7793e+00,  6.2205e+00, -2.0398e+00,  6.5748e+00],
        [-2.2827e-01,  2.3419e+00, -2.0372e+00,  2.6711e+00,  4.2027e-01],
        [-8.1374e-01,  3.0262e+00, -2.7809e+00,  3.8629e+00,  7.9667e-01],
        [ 1.9281e+00,  1.5542e+00, -4.4381e-01, -2.9086e+00, -5.3369e-01],
        [-9.5983e-02,  1.4478e+00, -1.1883e+00,  1.2839e+00,  6.4343e-01],
        [-3.8681e-01,  1.2142e+00, -1.0574e+00,  1.3426e+00,  7.5422e-01],
        [-6.7832e-02,  1.0641e+00, -6.1369e-01,  5.7797e-01,  8.3867e-01],
        [ 6.0566e-01, -6.3750e-01,  8.1522e-01, -2.8167e+00,  1.1439e+00],
        [-5.9569e-01,  5.7841e-01, -6.1744e-01,  1.5428e+00,  1.1588e+00],
        [-3.0632e+00, -5.4537e+00,  3.9571e+00, -3.8174e-01,  4.6711e+00],
        [-2.1098e+00, -4.

tensor([[4],
        [4],
        [3],
        [1],
        [1],
        [3],
        [4],
        [4],
        [4],
        [4],
        [0],
        [0],
        [1],
        [3],
        [1],
        [0],
        [2],
        [1],
        [0],
        [4],
        [3],
        [1],
        [1],
        [1],
        [4],
        [3],
        [0],
        [1],
        [1],
        [1]])
labels
tensor([0, 1, 4, 1, 0, 1, 4, 0, 4, 3, 0, 0, 1, 2, 1, 4, 4, 1, 1, 1, 3, 1, 0, 4,
        4, 3, 2, 1, 2, 4])
torch.Size([30, 3, 64, 64])
tensor([[ 1.3108e+00,  1.2145e+00, -7.7664e-01, -1.2158e+00, -3.2764e-01],
        [ 1.2070e+00,  4.2354e-01, -2.4505e-01, -1.5515e+00, -1.9855e-01],
        [ 1.1033e+00,  4.4972e-01, -7.4919e-01, -9.1138e-01,  2.8582e-01],
        [ 1.8358e+00,  4.5701e-01, -2.5364e-01, -2.3308e+00, -3.3265e-01],
        [-3.2312e+00, -4.3505e+00,  3.5963e+00,  3.8658e-01,  4.9451e+00],
        [ 1.3695e+00,  1.4314e+00, -6.8151e-01, -1.5218e+00, -4.2953e-01],
        [-1.7897e

tensor([[1],
        [2],
        [2],
        [1],
        [0],
        [2],
        [4],
        [0],
        [3],
        [4],
        [1],
        [4],
        [1],
        [0],
        [3],
        [4],
        [0],
        [3],
        [2],
        [3],
        [3],
        [1],
        [2],
        [2],
        [2],
        [1],
        [0],
        [3],
        [0],
        [3]])
labels
tensor([1, 3, 3, 3, 4, 0, 0, 3, 0, 3, 4, 1, 4, 1, 4, 2, 2, 1, 3, 3, 1, 0, 2, 2,
        4, 1, 3, 4, 2, 2])
torch.Size([30, 3, 64, 64])
tensor([[ 1.3256e+00, -1.3069e-01,  5.0888e-01, -2.5335e+00, -2.1877e-01],
        [ 1.3637e-01,  7.2042e-01, -1.0942e+00,  1.2634e+00,  5.2669e-01],
        [-1.7660e-01,  1.2026e+00, -1.4809e+00,  2.0196e+00,  5.5815e-01],
        [ 3.2722e-01,  5.5858e-01, -7.7172e-01,  6.0589e-01,  3.2916e-01],
        [ 1.5300e-01, -7.4105e-01,  3.5963e-01, -6.0646e-01,  1.0875e+00],
        [ 9.8421e-01,  2.9006e-01, -2.4645e-01, -8.9313e-01,  6.8448e-02],
        [ 4.5568e

tensor([[-3.5366e-02,  1.4441e+00, -2.0988e+00,  3.4898e+00,  4.4389e-01],
        [ 4.2062e-01,  1.0270e+00, -8.1039e-01,  6.1758e-01, -6.4015e-02],
        [-1.2532e-01, -4.8537e-01,  3.6687e-01, -1.6958e-01,  9.6855e-01],
        [ 2.3029e-01,  8.0970e-01, -7.2621e-01,  6.9641e-01,  8.4402e-02],
        [ 4.1526e-01,  4.2168e-01, -2.7246e-01,  1.4085e-03, -4.2120e-02],
        [-1.1791e-01, -1.1906e+00,  9.1534e-01, -1.5148e+00,  1.0806e+00],
        [ 2.4001e-01,  3.0688e-01, -2.3173e-01,  2.3717e-01,  1.7093e-01],
        [ 4.3986e-01,  7.3430e-02,  1.2622e-01, -7.0677e-01, -4.1913e-02],
        [ 3.3126e-01,  6.7866e-01, -7.1723e-01,  5.6023e-01,  3.3668e-01],
        [ 7.8167e-01,  4.2447e-01, -4.0136e-02, -1.0833e+00, -2.7186e-01],
        [ 1.7697e-01,  3.6208e-01, -2.4832e-01,  2.6118e-01,  9.3088e-02],
        [ 4.7909e-01,  4.3891e-01, -6.0934e-01,  4.7621e-01,  2.9304e-02],
        [ 2.4530e-01, -6.8619e-01,  7.9072e-01, -1.4215e+00,  7.2550e-01],
        [-2.0558e+00, -4.

tensor([[0],
        [1],
        [0],
        [1],
        [3],
        [0],
        [0],
        [0],
        [2],
        [1],
        [3],
        [1],
        [1],
        [4],
        [1],
        [2],
        [1],
        [1],
        [1],
        [4],
        [1],
        [1],
        [0],
        [2],
        [2],
        [3],
        [2],
        [3],
        [4],
        [2]])
labels
tensor([4, 0, 3, 1, 0, 1, 1, 2, 2, 0, 2, 0, 1, 2, 1, 4, 0, 0, 1, 4, 4, 1, 2, 1,
        2, 3, 4, 2, 2, 2])
torch.Size([30, 3, 64, 64])
tensor([[ 1.6920e-01,  7.4201e-01, -5.9265e-01,  8.3310e-01,  1.1668e-01],
        [ 8.9843e-02, -1.0294e+00,  1.4451e+00, -2.1133e+00,  6.7056e-01],
        [ 7.8900e-03, -4.0709e-01,  4.8893e-01, -5.1071e-01,  4.8726e-01],
        [ 1.1914e+00,  7.7991e-01,  1.3581e-01, -2.2053e+00, -6.7672e-01],
        [ 9.5885e-01,  4.7373e-01,  1.1411e-01, -1.2839e+00, -2.5996e-01],
        [ 2.4003e-01, -3.3460e-02,  4.5785e-01, -1.2420e+00,  1.6008e-01],
        [ 9.3076e

tensor([[-0.4730, -0.6656,  0.5106,  0.7033,  1.1781],
        [ 0.2156,  0.1442,  0.1323, -0.5142,  0.2322],
        [ 1.1467,  2.0846, -0.7486, -0.6983, -1.0248],
        [-3.7421, -7.2182,  6.0661, -3.1329,  6.0597],
        [ 0.4369,  0.3629,  0.0575, -0.8035, -0.0859],
        [ 1.2489,  1.4365, -0.1677, -1.9836, -0.9399],
        [ 0.4436,  1.0674, -0.8922,  1.0456, -0.2731],
        [-1.2624, -3.0487,  3.1395, -3.6180,  2.5384],
        [ 0.4409,  1.1627, -0.6442,  0.3106, -0.4234],
        [-2.6513, -4.2260,  3.0839,  0.2853,  4.5300],
        [ 0.6577,  0.8489, -0.1223, -0.9969, -0.5062],
        [ 1.8919,  2.1935, -0.2399, -3.1823, -1.5460],
        [ 0.6619,  2.4726, -2.0186,  2.3774, -0.6716],
        [ 0.8447,  1.2970, -0.2760, -1.1801, -0.6873],
        [-1.4045, -2.7025,  2.7093, -2.3870,  2.7920],
        [ 0.2830,  0.7581, -0.1112, -0.8370,  0.1734],
        [ 1.5547,  2.1399, -0.4479, -2.1182, -1.3247],
        [-3.2845, -5.1939,  4.1566, -0.8656,  5.3889],
        [ 

labels
tensor([3, 3, 1, 3, 4, 2, 2, 2, 4, 1, 1, 1, 4, 3, 1, 3, 1, 4, 3, 2, 0, 3, 1, 3,
        0, 4, 4, 1, 0, 3])
torch.Size([30, 3, 64, 64])
tensor([[ 0.3291,  2.0458, -2.0771,  3.0565, -0.4084],
        [-2.6596, -3.3178,  1.3428,  3.6615,  3.6979],
        [ 1.4566,  1.6934, -0.2217, -2.5325, -1.0903],
        [ 0.7155,  1.7625, -1.0375,  0.3730, -0.7111],
        [-0.7379,  0.0337, -0.2738,  1.8327,  1.3328],
        [-0.5834, -1.6006,  1.8083, -1.6527,  1.4485],
        [-2.8943, -5.4542,  4.4516, -1.9147,  5.0807],
        [-1.1867, -2.6686,  2.9105, -2.3900,  2.7643],
        [-1.3037, -2.1272,  1.3467,  0.0674,  2.6059],
        [ 0.6243,  1.7775, -2.1095,  2.9510, -0.4071],
        [ 0.4214,  2.6091, -2.5198,  3.5285, -0.6124],
        [ 0.6026,  1.9518, -1.2779,  0.8589, -0.5681],
        [ 0.3254,  0.1463, -0.3718,  1.1594,  0.4412],
        [ 0.1905,  2.7825, -3.2235,  5.4362, -0.3055],
        [ 1.3714,  2.1894, -1.3573, -0.0574, -1.1609],
        [-1.5404, -0.0420, -1.238

tensor([0, 4, 3, 1, 2, 4, 0, 3, 4, 2, 3, 2, 1, 0, 4, 3, 1, 4, 3, 0, 1, 4, 2, 1,
        2, 2, 1, 2, 2, 2])
torch.Size([30, 3, 64, 64])
tensor([[-3.2468, -6.0420,  4.5705, -1.8685,  4.3617],
        [-2.9258, -5.1268,  4.0830, -1.4479,  5.3899],
        [-0.2809,  0.6764, -1.0035,  2.3085,  1.2187],
        [ 0.1163,  1.2798, -0.9192,  1.1315,  0.0336],
        [ 1.8268,  1.8963, -0.3585, -3.1249, -1.3263],
        [-1.1457, -1.3751,  0.9008,  1.0322,  1.7980],
        [ 0.0984,  1.0531, -1.1528,  1.6791, -0.0957],
        [ 0.3667,  2.0811, -1.8000,  2.3915, -0.3287],
        [ 0.8586,  1.0962, -0.1983, -1.1704, -0.6440],
        [-0.6878, -2.1702,  2.2257, -2.6234,  2.4323],
        [ 0.3615,  0.9256, -0.6068,  0.2564, -0.3169],
        [-1.1041, -2.5995,  2.4014, -2.1466,  2.3352],
        [ 0.5948,  1.0513, -0.2972, -0.7628, -0.5713],
        [ 0.4755,  1.4946, -1.1767,  1.1517, -0.4001],
        [ 2.5056,  3.5658, -1.0529, -3.3197, -2.4534],
        [ 0.1781,  1.7206, -1.5597,  2.2

labels
tensor([4, 1, 3, 4, 3, 1, 2, 1, 0, 3, 1, 1, 4, 4, 4, 3, 1, 3, 3, 1, 0, 3, 0, 1,
        4, 0, 4, 1, 1, 4])
torch.Size([30, 3, 64, 64])
tensor([[-2.1989e+00, -4.3378e+00,  4.6352e+00, -3.3298e+00,  2.9797e+00],
        [ 1.0252e+00,  1.2156e+00, -1.7419e-01, -1.7643e+00, -7.4440e-01],
        [-2.2832e+00, -3.4425e+00,  2.7395e+00, -1.8455e-01,  3.2180e+00],
        [ 2.8621e-01, -5.3114e-01,  9.8699e-01, -2.6033e+00,  4.8643e-01],
        [ 1.6230e-01,  6.3300e-01, -9.1698e-02, -1.8788e-01,  3.3488e-04],
        [ 5.8519e-01,  1.0187e+00, -2.2518e-01, -1.0252e+00, -4.5323e-01],
        [-2.6514e+00, -5.6443e+00,  5.1585e+00, -4.1862e+00,  3.4517e+00],
        [ 1.0529e+00,  1.2376e+00, -1.7105e-01, -1.8128e+00, -8.9181e-01],
        [ 1.1535e+00,  4.9168e-01,  1.6546e-01, -2.3371e+00, -5.9170e-01],
        [ 1.0386e-01,  8.5678e-01, -6.4460e-01,  6.3664e-01,  5.9888e-02],
        [-2.3716e-01,  4.6608e-01, -5.6337e-01,  1.1008e+00,  5.1228e-01],
        [ 1.1062e-01, -1.2360e-01

labels
tensor([2, 2, 1, 3, 4, 1, 0, 4, 0, 2, 0, 0, 0, 2, 2, 4, 0, 1, 1, 1, 0, 2, 3, 4,
        0, 2, 0, 1, 1, 3])
torch.Size([30, 3, 64, 64])
tensor([[-7.7256e-02,  1.2136e-02,  2.2259e-01,  1.9483e-01,  1.2149e-01],
        [-8.7432e-01, -1.5255e+00,  1.1097e+00,  6.5021e-01,  1.4781e+00],
        [ 2.9762e-01,  5.0314e-01, -1.6493e-01, -2.1122e-01, -1.6783e-01],
        [ 2.1514e-01,  1.4958e+00, -1.1036e+00,  1.0027e+00, -2.9454e-01],
        [-9.1794e-01, -1.2796e+00,  1.3156e+00, -9.3924e-01,  1.3638e+00],
        [ 2.9101e-02,  9.8465e-01, -1.0421e+00,  1.2368e+00,  9.8590e-02],
        [ 4.3390e-01,  9.2423e-01, -3.1893e-01, -6.9098e-01, -3.1930e-01],
        [ 3.5207e-01,  6.4044e-01, -3.0524e-01, -2.1690e-01, -3.4708e-01],
        [ 1.1948e-01,  6.3304e-01, -4.2142e-01,  5.8823e-01, -4.3199e-02],
        [-1.7550e+00, -2.7432e+00,  2.3316e+00,  7.2455e-01,  2.4706e+00],
        [ 7.4052e-01,  1.0247e+00, -2.1288e-01, -1.4328e+00, -5.6264e-01],
        [-7.2529e-02,  6.4350e-01

tensor([4, 4, 1, 1, 2, 3, 0, 0, 1, 4, 1, 1, 1, 1, 1, 2, 1, 3, 0, 2, 1, 3, 1, 4,
        3, 3, 1, 1, 3, 3])
torch.Size([30, 3, 64, 64])
tensor([[ 0.1633,  0.6200, -0.1789, -0.2311, -0.0862],
        [-0.2408, -0.5617,  0.7482, -0.6291,  0.3758],
        [ 0.6521,  1.0365, -0.6243, -0.2778, -0.4918],
        [ 0.2352,  1.0956, -0.9161,  1.3032, -0.1884],
        [-1.6031, -2.6414,  2.7354, -1.0631,  1.8280],
        [ 0.8160,  1.3487, -1.0977,  0.8120, -0.7712],
        [ 0.4322,  1.1323, -0.6217, -0.0881, -0.3279],
        [ 0.8646,  1.2712, -0.7355, -0.6892, -0.7258],
        [ 0.4673,  2.6442, -2.1188,  2.0202, -0.6545],
        [-1.7957, -3.0000,  2.3840, -0.3533,  2.5592],
        [ 0.4781,  1.0092, -0.5875,  0.2358, -0.5843],
        [-0.0087,  0.0951,  0.3229, -0.8406,  0.1066],
        [ 0.1852,  0.7622, -0.5021,  0.4293, -0.2741],
        [ 0.7407,  1.6129, -1.0477, -0.2934, -0.6397],
        [-0.4271, -0.5609,  0.7051, -0.2795,  0.5226],
        [-1.6838, -2.5668,  2.5254, -1.1

tensor([1, 1, 4, 4, 2, 3, 0, 3, 3, 3, 3, 1, 2, 1, 3, 1, 2, 4, 2, 3, 0, 1, 1, 1,
        0, 3, 0, 1, 0, 2])
torch.Size([30, 3, 64, 64])
tensor([[ 0.4259,  0.8589, -0.4432, -0.2995, -0.2331],
        [ 0.5128,  1.1555, -0.6513, -0.0480, -0.3559],
        [-0.7566, -0.8542,  1.0546, -1.5521,  1.0106],
        [-3.1019, -4.6460,  5.0014, -4.0192,  3.7457],
        [-0.0499,  0.0760,  0.6898, -1.6812,  0.2597],
        [ 0.3999,  1.1256, -1.0527,  1.0060, -0.3054],
        [ 0.5506,  0.6952, -0.2086, -0.9769, -0.1144],
        [ 0.2136,  1.2533, -0.9694,  1.0937, -0.2089],
        [ 0.1994,  0.8312, -0.3790, -0.0943,  0.0182],
        [ 0.0545,  0.2947,  0.2761, -1.5234,  0.5152],
        [ 0.6142,  1.3703, -1.3963,  0.8541, -0.2044],
        [ 0.5836,  1.2113, -0.7577, -0.2492, -0.3782],
        [-2.2130, -2.8075,  2.9284, -1.4247,  2.2840],
        [ 0.1540,  0.8660, -0.1732, -0.6391, -0.1031],
        [ 0.2228,  1.1369, -1.2295,  1.9310, -0.1730],
        [ 0.1925,  1.0991, -0.6406,  0.6

KeyboardInterrupt: 

In [None]:
torch.save(CNNmodel.state_dict(), 'cnn_model.pt')

In [None]:
plt.plot(train_loss, label='train loss')
plt.plot(test_loss, label='test loss')
plt.legend()

In [None]:
plt.plot(train_correct, label='train correct')
plt.plot(test_correct, label='test correct')
plt.legend()