In [1]:
import torch
import torch.nn as nn
import argparse
import os
from torchsummary import summary
from torchvision import datasets, transforms
from kymatio.torch import Scattering2D
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

train_transforms = [transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                    normalize,
                   ]

In [3]:
train_set = datasets.CIFAR10(root=".data", train=True,
                                     transform=transforms.Compose(train_transforms),
                                     download=True)
test_set = datasets.CIFAR10(root=".data", train=False,
                                    transform=transforms.Compose(
                                        [transforms.ToTensor(), normalize]
                                    ))

Files already downloaded and verified


In [4]:
class ScatterLinear(nn.Module):
    def __init__(self, in_channels, hw_dims, input_norm=None, classes=10, **kwargs):
        super(ScatterLinear, self).__init__()
        self.Scattering2D = Scattering2D(J=2, shape=(32, 32))
        self.Scattering2D = self.Scattering2D.cuda()
        self.K = in_channels
        self.h = hw_dims[0]
        self.w = hw_dims[1]
        self.fc = None
        self.norm = None
        self.build(input_norm, classes=classes, **kwargs)

    def build(self, input_norm=None, num_groups=None, bn_stats=None, clip_norm=None, classes=10):
        self.fc = nn.Linear(self.K * self.h * self.w, classes)

        if input_norm is None:
            self.norm = nn.Identity()
        elif input_norm == "BatchNorm":
            self.norm = nn.BatchNorm2d(num_features=self.K,affine=False)

    def forward(self, x):
        x = self.Scattering2D(x)
        x = x.view(-1, 3 * 81*8*8)
        torch.flatten(x, start_dim=1)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

In [5]:
model = ScatterLinear(in_channels=243,
                      hw_dims=[8,8],
                      input_norm='BatchNorm',
                      classes=10,
                      clip_norm=None)

In [6]:
train_loader = torch.utils.data.DataLoader(train_set, 
                                           batch_size=512,
                                           shuffle=True,
                                           num_workers=6, 
                                           drop_last=True)
test_loader = torch.utils.data.DataLoader(train_set, 
                                           batch_size=1024,
                                           shuffle=False,
                                           num_workers=6, 
                                           drop_last=False)

In [7]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                                    momentum=0.9,
                                    nesterov=False)
scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

In [8]:
def train(model, train_loader, optimizer):
    model = model.to(device)
    model.train()
    num_examples = 0
    correct = 0
    train_loss = 0
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        output = model(data)
        #print(data.shape,output.shape, target.shape)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum().item()
        train_loss += F.cross_entropy(output, target, reduction='sum').item()
        num_examples += len(data)

    train_loss /= num_examples
    train_acc = 100. * correct / num_examples
    print(f'Train set: Average loss: {train_loss:.4f}, '
          f'Accuracy: {correct}/{num_examples} ({train_acc:.2f}%)')
    return train_loss, train_acc

In [9]:
def test(model, test_loader):
    device = next(model.parameters()).device
    
    model.eval()
    num_examples = 0
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            num_examples += len(data)

    test_loss /= num_examples
    test_acc = 100. * correct / num_examples

    print(f'Test set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{num_examples} ({test_acc:.2f}%)')

    return test_loss, test_acc

In [None]:
Epochs = 100
train_loss_lst=[]
train_acc_lst=[]
val_loss_lst=[]
val_acc_lst=[]
for epoch in range(0, Epochs):
    print(f"\nEpoch: {epoch}")
    train_loss, train_acc = train(model, train_loader, optimizer)
    test_loss, test_acc = test(model, test_loader)
    train_loss_lst.append(train_loss)
    train_acc_lst.append(train_acc)
    val_loss_lst.append(test_loss)
    val_acc_lst.append(test_acc)
    print(train_loss_lst)
    print(train_acc_lst)
    print(val_loss_lst)
    print(val_acc_lst)
    scheduler.step()


Epoch: 0


97it [00:13,  7.20it/s]


Train set: Average loss: 1.7494, Accuracy: 18760/49664 (37.77%)


100%|██████████| 49/49 [00:13<00:00,  3.54it/s]

Test set: Average loss: 1.5840, Accuracy: 21958/50000 (43.92%)
[1.7493649168112844]
[37.77384020618557]
[1.5840436083984375]
[43.916]

Epoch: 1



97it [00:13,  7.19it/s]


Train set: Average loss: 1.5285, Accuracy: 22867/49664 (46.04%)


100%|██████████| 49/49 [00:13<00:00,  3.56it/s]

Test set: Average loss: 1.4828, Accuracy: 24065/50000 (48.13%)
[1.7493649168112844, 1.5285044429228478]
[37.77384020618557, 46.04341172680412]
[1.5840436083984375, 1.4828163818359374]
[43.916, 48.13]

Epoch: 2



97it [00:13,  7.18it/s]


Train set: Average loss: 1.4676, Accuracy: 24084/49664 (48.49%)


100%|██████████| 49/49 [00:13<00:00,  3.55it/s]

Test set: Average loss: 1.4785, Accuracy: 24390/50000 (48.78%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146]
[37.77384020618557, 46.04341172680412, 48.49387886597938]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187]
[43.916, 48.13, 48.78]

Epoch: 3



97it [00:13,  7.02it/s]


Train set: Average loss: 1.4330, Accuracy: 24666/49664 (49.67%)


100%|██████████| 49/49 [00:14<00:00,  3.42it/s]

Test set: Average loss: 1.4022, Accuracy: 25175/50000 (50.35%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375]
[43.916, 48.13, 48.78, 50.35]

Epoch: 4



97it [00:14,  6.92it/s]


Train set: Average loss: 1.3974, Accuracy: 25443/49664 (51.23%)


100%|██████████| 49/49 [00:14<00:00,  3.36it/s]

Test set: Average loss: 1.3671, Accuracy: 26335/50000 (52.67%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437]
[43.916, 48.13, 48.78, 50.35, 52.67]

Epoch: 5



97it [00:13,  6.93it/s]


Train set: Average loss: 1.3838, Accuracy: 25718/49664 (51.78%)


100%|██████████| 49/49 [00:14<00:00,  3.39it/s]

Test set: Average loss: 1.3714, Accuracy: 26144/50000 (52.29%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288]

Epoch: 6



97it [00:14,  6.91it/s]


Train set: Average loss: 1.3590, Accuracy: 26266/49664 (52.89%)


100%|██████████| 49/49 [00:14<00:00,  3.38it/s]

Test set: Average loss: 1.3291, Accuracy: 26887/50000 (53.77%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774]

Epoch: 7



97it [00:13,  6.93it/s]


Train set: Average loss: 1.3544, Accuracy: 26253/49664 (52.86%)


100%|██████████| 49/49 [00:14<00:00,  3.43it/s]

Test set: Average loss: 1.3307, Accuracy: 26698/50000 (53.40%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396]

Epoch: 8



97it [00:14,  6.87it/s]


Train set: Average loss: 1.3355, Accuracy: 26592/49664 (53.54%)


100%|██████████| 49/49 [00:14<00:00,  3.34it/s]

Test set: Average loss: 1.3525, Accuracy: 26276/50000 (52.55%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396, 52.552]

Epoch: 9



97it [00:14,  6.92it/s]


Train set: Average loss: 1.3249, Accuracy: 26756/49664 (53.87%)


100%|██████████| 49/49 [00:14<00:00,  3.38it/s]

Test set: Average loss: 1.2928, Accuracy: 27749/50000 (55.50%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396, 52.552, 55.498]

Epoch: 10



97it [00:14,  6.87it/s]


Train set: Average loss: 1.3167, Accuracy: 26955/49664 (54.27%)


100%|██████████| 49/49 [00:14<00:00,  3.45it/s]

Test set: Average loss: 1.3173, Accuracy: 26984/50000 (53.97%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396, 52.552, 55.498, 53.968]

Epoch: 11



97it [00:13,  6.94it/s]


Train set: Average loss: 1.3182, Accuracy: 26934/49664 (54.23%)


100%|██████████| 49/49 [00:14<00:00,  3.40it/s]

Test set: Average loss: 1.3098, Accuracy: 27278/50000 (54.56%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.3097660595703124]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396, 52.552, 55.498, 53.968, 54.556]

Epoch: 12



97it [00:13,  6.93it/s]


Train set: Average loss: 1.3053, Accuracy: 27062/49664 (54.49%)


100%|██████████| 49/49 [00:14<00:00,  3.41it/s]

Test set: Average loss: 1.2900, Accuracy: 27761/50000 (55.52%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.3097660595703124, 1.2899841040039062]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396, 52.552, 55.498, 53.968, 54.556, 55.522]

Epoch: 13



97it [00:13,  6.96it/s]


Train set: Average loss: 1.3051, Accuracy: 27150/49664 (54.67%)


100%|██████████| 49/49 [00:14<00:00,  3.42it/s]

Test set: Average loss: 1.2820, Accuracy: 27709/50000 (55.42%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.3097660595703124, 1.2899841040039062, 1.2819846118164062]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774, 53.396, 52.552, 55.498, 53.968, 54.556, 55.522, 55.418]




97it [00:13,  6.93it/s]


Train set: Average loss: 1.3004, Accuracy: 27232/49664 (54.83%)


100%|██████████| 49/49 [00:14<00:00,  3.38it/s]

Test set: Average loss: 1.2953, Accuracy: 27788/50000 (55.58%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.3097660595703124, 1.2899841040039062, 1.2819846118164062, 1.2952664184570313]
[43.916, 48.13, 48.78, 50.35, 52.67, 52.288, 53.774


97it [00:13,  6.95it/s]


Train set: Average loss: 1.2851, Accuracy: 27555/49664 (55.48%)


100%|██████████| 49/49 [00:14<00:00,  3.44it/s]

Test set: Average loss: 1.2667, Accuracy: 28188/50000 (56.38%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.3097660595703124, 1.2899841040039062, 1.2819846118164062, 1.2952664184570313, 1.266665925


97it [00:13,  6.95it/s]


Train set: Average loss: 1.2920, Accuracy: 27464/49664 (55.30%)


100%|██████████| 49/49 [00:14<00:00,  3.41it/s]

Test set: Average loss: 1.2735, Accuracy: 27949/50000 (55.90%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.3097660595703124, 1.2899841040039062, 1.28198461181


97it [00:13,  6.99it/s]


Train set: Average loss: 1.2957, Accuracy: 27345/49664 (55.06%)


100%|██████████| 49/49 [00:14<00:00,  3.45it/s]

Test set: Average loss: 1.2525, Accuracy: 28457/50000 (56.91%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007813, 1.317270166015625, 1.309766059570


97it [00:13,  6.96it/s]


Train set: Average loss: 1.2717, Accuracy: 27904/49664 (56.19%)


100%|██████████| 49/49 [00:14<00:00,  3.39it/s]

Test set: Average loss: 1.2570, Accuracy: 28482/50000 (56.96%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101562, 1.3524957641601563, 1.2928422583007


97it [00:13,  6.96it/s]


Train set: Average loss: 1.2727, Accuracy: 27820/49664 (56.02%)


100%|██████████| 49/49 [00:14<00:00,  3.45it/s]

Test set: Average loss: 1.2445, Accuracy: 28515/50000 (57.03%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570313, 1.329057490234375, 1.3306577954101


97it [00:13,  6.99it/s]


Train set: Average loss: 1.2761, Accuracy: 27803/49664 (55.98%)


100%|██████████| 49/49 [00:14<00:00,  3.40it/s]

Test set: Average loss: 1.2708, Accuracy: 27919/50000 (55.84%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.98220038659794]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.402203349609375, 1.3671045092773437, 1.3713851684570


97it [00:13,  6.97it/s]


Train set: Average loss: 1.2695, Accuracy: 27796/49664 (55.97%)


100%|██████████| 49/49 [00:14<00:00,  3.45it/s]

Test set: Average loss: 1.2947, Accuracy: 27716/50000 (55.43%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.98220038659794, 55.96810567010309]
[1.5840436083984375, 1.4828163818359374, 1.4785439916992187, 1.40220334960937


97it [00:13,  6.97it/s]


Train set: Average loss: 1.2737, Accuracy: 27787/49664 (55.95%)


100%|██████████| 49/49 [00:14<00:00,  3.45it/s]

Test set: Average loss: 1.2806, Accuracy: 27706/50000 (55.41%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.98220038659794, 55.96810567010309, 55.94998389175258]
[1.5840436083984375, 1.482816381835937


97it [00:13,  6.96it/s]


Train set: Average loss: 1.2635, Accuracy: 28013/49664 (56.41%)


100%|██████████| 49/49 [00:14<00:00,  3.41it/s]

Test set: Average loss: 1.2816, Accuracy: 27623/50000 (55.25%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.98220038659794, 55.96810567010309, 55.94998389175258, 56.405041881443296


97it [00:13,  6.98it/s]


Train set: Average loss: 1.2586, Accuracy: 28106/49664 (56.59%)


100%|██████████| 49/49 [00:14<00:00,  3.39it/s]

Test set: Average loss: 1.2476, Accuracy: 28384/50000 (56.77%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.98220038659794, 55.96810567010309, 55.94998389175258


97it [00:13,  6.95it/s]


Train set: Average loss: 1.2616, Accuracy: 27993/49664 (56.36%)


100%|██████████| 49/49 [00:13<00:00,  3.54it/s]

Test set: Average loss: 1.2679, Accuracy: 27896/50000 (55.79%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.98220038659794, 55.9681056701030


97it [00:13,  7.17it/s]


Train set: Average loss: 1.2609, Accuracy: 28077/49664 (56.53%)


100%|██████████| 49/49 [00:13<00:00,  3.56it/s]

Test set: Average loss: 1.2330, Accuracy: 28699/50000 (57.40%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371136, 55.982200386597


97it [00:13,  7.18it/s]


Train set: Average loss: 1.2486, Accuracy: 28314/49664 (57.01%)


100%|██████████| 49/49 [00:13<00:00,  3.59it/s]

Test set: Average loss: 1.2385, Accuracy: 28796/50000 (57.59%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030928, 56.016430412371


97it [00:13,  7.20it/s]


Train set: Average loss: 1.2453, Accuracy: 28327/49664 (57.04%)


100%|██████████| 49/49 [00:13<00:00,  3.59it/s]

Test set: Average loss: 1.2208, Accuracy: 29122/50000 (58.24%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164948, 56.18556701030


97it [00:13,  7.16it/s]


Train set: Average loss: 1.2386, Accuracy: 28510/49664 (57.41%)


100%|██████████| 49/49 [00:13<00:00,  3.59it/s]

Test set: Average loss: 1.2310, Accuracy: 28775/50000 (57.55%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302, 1.238563591671973]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206186, 55.06000322164


97it [00:13,  7.19it/s]


Train set: Average loss: 1.1829, Accuracy: 29569/49664 (59.54%)


100%|██████████| 49/49 [00:13<00:00,  3.57it/s]

Test set: Average loss: 1.1707, Accuracy: 30108/50000 (60.22%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302, 1.238563591671973, 1.182949243132601]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.482844716494846, 55.29961340206


97it [00:13,  7.17it/s]


Train set: Average loss: 1.1712, Accuracy: 29909/49664 (60.22%)


100%|██████████| 49/49 [00:14<00:00,  3.50it/s]

Test set: Average loss: 1.1662, Accuracy: 30170/50000 (60.34%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302, 1.238563591671973, 1.182949243132601, 1.1711935751216929]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.83247422680412, 55.48284471649


97it [00:13,  7.21it/s]


Train set: Average loss: 1.1705, Accuracy: 29975/49664 (60.36%)


100%|██████████| 49/49 [00:13<00:00,  3.58it/s]

Test set: Average loss: 1.1757, Accuracy: 30018/50000 (60.04%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302, 1.238563591671973, 1.182949243132601, 1.1711935751216929, 1.1705198300253485]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.66736469072165, 54.8324742268


97it [00:13,  7.18it/s]


Train set: Average loss: 1.1710, Accuracy: 29983/49664 (60.37%)


100%|██████████| 49/49 [00:13<00:00,  3.59it/s]

Test set: Average loss: 1.1642, Accuracy: 30071/50000 (60.14%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302, 1.238563591671973, 1.182949243132601, 1.1711935751216929, 1.1705198300253485, 1.1709989702578671]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396907217, 54.667364690


97it [00:13,  7.18it/s]


Train set: Average loss: 1.1702, Accuracy: 29944/49664 (60.29%)


100%|██████████| 49/49 [00:13<00:00,  3.59it/s]

Test set: Average loss: 1.1705, Accuracy: 30107/50000 (60.21%)
[1.7493649168112844, 1.5285044429228478, 1.4675812794990146, 1.4329862373391378, 1.3974224513339013, 1.3838434108753794, 1.3590353874816108, 1.3544005942098873, 1.3355169259395796, 1.3249074429580845, 1.316713673552287, 1.3182192672159254, 1.3052511104603404, 1.3050669505424106, 1.3003858106652486, 1.2850659090219085, 1.291957246888544, 1.2956515771826518, 1.271739296077453, 1.2726994752883911, 1.2761381508148824, 1.269474671059048, 1.2737478066965477, 1.2635394125869595, 1.2585786015717024, 1.2615706994361484, 1.2608905330146711, 1.2485677003860474, 1.2452633528365302, 1.238563591671973, 1.182949243132601, 1.1711935751216929, 1.1705198300253485, 1.1709989702578671, 1.1701623557769145]
[37.77384020618557, 46.04341172680412, 48.49387886597938, 49.66575386597938, 51.230267396907216, 51.78398840206186, 52.88740335051546, 52.86122744845361, 53.54381443298969, 53.87403350515464, 54.274726159793815, 54.23244201030928, 54.49017396


97it [00:13,  7.17it/s]


Train set: Average loss: 1.1694, Accuracy: 30013/49664 (60.43%)


 84%|████████▎ | 41/49 [00:11<00:02,  3.76it/s]