In [254]:
from functools import lru_cache
from typing import Union

import os
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import Series

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

from einops import rearrange, repeat
from transformers import AutoTokenizer

In [255]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, nhead, pdrop):
        super(MultiHeadAttention, self).__init__()

        self.nhead = nhead

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        self.drop = nn.Dropout(pdrop)

    def forward(self, x, attn_mask=None):
        _, t, _ = x.shape

        qkv = self.qkv(x).chunk(3, dim=-1)

        q, k, v = map(
            lambda tsn: rearrange(tsn, 'b t (nh hd) -> b nh t hd', nh=self.nhead), qkv
        )

        wei = q @ k.transpose(-1, -2) * k.size(dim=-1) ** -0.5
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.device)
            wei = wei.masked_fill(attn_mask[:t, :t] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        attn = wei @ v
        attn = rearrange(attn, 'b nh t hd -> b t (nh hd)')

        return self.drop(self.c_proj(attn))

In [256]:
class FFN(nn.Module):
    def __init__(self, embed_dim, pdrop):
        super(FFN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(pdrop),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self,
                 # Common
                 embed_dim,
                 # MHA
                 nhead,
                 attn_pdrop,
                 # FFN
                 ffn_pdrop
                 ):
        super(Block, self).__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, nhead, attn_pdrop)
        self.ffn = FFN(embed_dim, ffn_pdrop)

    def forward(self, x, attn_mask=None):
        x = x + self.attn(self.ln1(x), attn_mask=attn_mask)
        x = x + self.ffn(self.ln2(x))
        return x

In [257]:
class TransformerLanguageModel(nn.Module):
    def __init__(self,
                 # LM
                 vocab_size,
                 ctx_len,
                 # Common
                 embed_dim,
                 # Block
                 nlayer,
                 nhead,
                 attn_mask=None,
                 attn_pdrop=0.2,
                 ffn_pdrop=0.2,
                 # Output
                 output_dim: int=None
                 ):
        super(TransformerLanguageModel, self).__init__()
        # Properties
        self.ctx_len = ctx_len
        if attn_mask is None:
            attn_mask = torch.tril(torch.ones(ctx_len, ctx_len))
        self.attn_mask = attn_mask
        self.understanding = output_dim is not None

        # Comps
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(ctx_len, embed_dim)

        self.ln_pre = nn.LayerNorm(embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, attn_pdrop, ffn_pdrop) for _ in range(nlayer)]
        )
        self.ln_post = nn.LayerNorm(embed_dim)

        self.lm_head = nn.Linear(embed_dim, vocab_size if output_dim is None else output_dim)

    def generate(self, idx, attn_mask=None, max_new_tokens=100, temp=1.0, topk=None):
        if self.understanding:
            raise NotImplementedError('Not for generating')
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.ctx_len:]
            logits = self(idx_cond, attn_mask)
            logits_of_last_tick = logits[:, -1, :] / temp
            if topk is not None:
                vals, _ = torch.topk(logits_of_last_tick, topk)
                logits_of_last_tick[logits_of_last_tick < vals[-1]] = float('-inf')
            if temp == 0:
                idx_next = torch.argmax(logits_of_last_tick, dim=-1, keepdim=True)
            else:
                probs = logits_of_last_tick.softmax(dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1) # concat on 'time' dimension
        return idx

    def forward(self, idx, attn_mask=None):
        attn_mask = attn_mask if attn_mask is not None else self.attn_mask

        _, t = idx.shape

        tok_emb = self.tok_emb(idx)
        pos_emb = self.pos_emb(
            torch.arange(0, t, dtype=torch.long, device=idx.device)
        )

        x = tok_emb + pos_emb

        x = self.ln_pre(x)
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        x = self.ln_post(x)

        logits = self.lm_head(x)

        if self.understanding:
            logits = logits[torch.arange(logits.shape[0]), idx.argmax(dim=-1)]

        return logits

In [258]:
def pair(sz: Union[int, tuple]):
    if isinstance(sz, int):
        return sz, sz
    else:
        return sz

class VisionTransformer(nn.Module):
    def __init__(self,
                 # Vision
                 img_size: Union[int, tuple],
                 patch_num: Union[int, tuple],
                 use_cls_tok: bool,
                 # Transformer
                 nlayer: int,
                 embed_dim: int,
                 nhead: int,
                 attn_pdrop: float,
                 ffn_pdrop: float,
                 # Output
                 num_classes: int
                 ):
        super(VisionTransformer, self).__init__()
        # Properties
        img_height, img_width = pair(img_size)
        patch_height_n, patch_width_n = pair(patch_num)
        scale = embed_dim ** -0.5

        self.patch_height, self.patch_width = img_height // patch_height_n, img_width // patch_width_n
        self.use_cls_tok = use_cls_tok

        # Comps
        self.to_patch = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=(patch_height_n, patch_width_n),
            stride=(patch_height_n, patch_width_n),
        )

        self.pos_emb = nn.Parameter(scale * torch.randn(1, self.patch_height * self.patch_width + 1, embed_dim))
        self.cls_emb = nn.Parameter(scale *  torch.randn(embed_dim))

        self.ln_pre = nn.LayerNorm(embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, attn_pdrop, ffn_pdrop) for _ in range(nlayer)]
        )
        self.ln_post = nn.LayerNorm(embed_dim)

        self.proj = nn.Parameter(scale * torch.randn(embed_dim, num_classes))

    def forward(self, x):
        patches = self.to_patch(x)

        idx = rearrange(
            patches, 'b c (ph h) (pw w) -> b (ph pw) (c h w)', ph=self.patch_height, pw=self.patch_width
        )


        idx = torch.cat(
            (
                idx,
                repeat(self.cls_emb, 'e -> b 1 e', b=idx.shape[0])
            ), dim=1
        )
        x = idx + self.pos_emb


        x = self.ln_pre(x)
        for block in self.blocks:
            x = block(x)

        if self.use_cls_tok:
            x = x[:, 0, :]
        else:
            x = x[:, 1:, :].mean(dim=1)

        x = self.ln_post(x)

        return x @ self.proj

In [259]:
class CLIP(nn.Module):
    def __init__(self,
                 img_encoder,
                 text_encoder,
                 ):
        super(CLIP, self).__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.img_encoder = img_encoder
        self.text_encoder = text_encoder

    def encode_img(self, img):
        return self.img_encoder(img)

    def encode_text(self, text):
        return self.text_encoder(text)

    def forward(self, imgs, texts):
        img_features = self.encode_img(imgs)
        text_features = self.encode_text(texts)

        if img_features.shape != text_features.shape:
            raise ValueError(f'img_features shape {img_features.shape} != text_features shape {text_features.shape}')

        img_features = img_features / img_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        logits_per_img = img_features @ text_features.t() * self.logit_scale.exp()
        logits_per_text = logits_per_img.t()

        return logits_per_img, logits_per_text

In [260]:
ctx = Series()

ctx.train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

ctx.test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

train_cifar10_ds = datasets.CIFAR10(root='~/data', train=True, transform=ctx.train_transform)
test_cifar10_ds = datasets.CIFAR10(root='~/data', train=True, transform=ctx.test_transform)

In [261]:
ctx.device = 'cuda' if torch.cuda.is_available() else 'cpu'
ctx.tokenizer = AutoTokenizer.from_pretrained('gpt2')
ctx.tokenizer.pad_token = ctx.tokenizer.eos_token
ctx.label2cls_dict = train_cifar10_ds.classes
ctx.ctx_len = 5
ctx.batch_size = 512

def label2sentence(label, prefix='The photo of a '):
    return prefix + ctx.label2cls_dict[label]

def tokenize(s):
    return ctx.tokenizer(
        s,
        max_length=ctx.ctx_len,
        padding='max_length',
        return_tensors='pt',
    )['input_ids'].squeeze(dim=0)

In [262]:
class TextImgDataset(Dataset):
    def __init__(self, img_label_ds):
        self.img_label_ds = img_label_ds

    def __len__(self):
        return len(self.img_label_ds)

    def __getitem__(self, index):
        orig_img, orig_label = self.img_label_ds[index]
        sentence = label2sentence(orig_label)
        return orig_img, tokenize(sentence)

train_dl = DataLoader(
    TextImgDataset(train_cifar10_ds),
    batch_size=ctx.batch_size,
    shuffle=True,
)

test_dl = DataLoader(
    TextImgDataset(test_cifar10_ds),
    batch_size=ctx.batch_size,
    shuffle=False,
)

for imgs, texts in train_dl:
    print(imgs.shape, texts.shape)
    break

torch.Size([512, 3, 32, 32]) torch.Size([512, 5])


In [263]:
def clip_loss(logits_per_image, logits_per_text):
    b = logits_per_text.shape[0]
    labels = torch.arange(0, b, device=logits_per_text.device)
    image_loss = F.cross_entropy(logits_per_image, labels)
    text_loss = F.cross_entropy(logits_per_text, labels)
    return (image_loss + text_loss) / 2

ctx.epochs = 256
ctx.lr = 1e-4
ctx.weight_decay = 1e-5
ctx.eval_interval = 8
ctx.saved_w = 'clip.pt'
ctx.T_max = 32

model = CLIP(
    img_encoder=VisionTransformer(
        img_size=32,
        patch_num=8,
        use_cls_tok=False,
        nlayer=4,
        embed_dim=768,
        nhead=12,
        attn_pdrop=0.2,
        ffn_pdrop=0.2,
        num_classes=128,
    ),
    text_encoder=TransformerLanguageModel(
        vocab_size=ctx.tokenizer.vocab_size,
        ctx_len=ctx.ctx_len,
        embed_dim=768,
        nlayer=4,
        nhead=12,
        attn_mask=None, # Default: Casual Attention Mask
        attn_pdrop=0.2,
        ffn_pdrop=0.2,
        output_dim=128
    )
).to(ctx.device)

if os.path.isfile(ctx.saved_w):
    model.load_state_dict(torch.load(ctx.saved_w, weights_only=True))

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=ctx.lr,
    weight_decay=ctx.weight_decay,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, ctx.T_max
)

@lru_cache
def tokenized_sentences():
    labels = range(len(ctx.label2cls_dict))
    return torch.stack([tokenize(label2sentence(label)) for label in labels])

@torch.no_grad()
def eval_accuracy():
    def eq_cnt(l1, l2):
        n = len(l1)
        cnt = 0
        for i in range(n):
            if l1[i] == l2[i]:
                cnt += 1
        return cnt
    model.eval()
    text_features = model.encode_text(
        tokenized_sentences().to(ctx.device)
    )
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    total = 0
    correct = 0
    pbar = tqdm(test_dl)
    pbar.set_description_str('Evaling accuracy on test set')
    for imgs, texts in pbar:
        imgs = imgs.to(ctx.device)
        texts = texts.to(ctx.device)
        image_features = model.encode_img(imgs)
        logits_per_image = image_features @ text_features.t()
        labels = logits_per_image.argmax(dim=1)
        pred = [label2sentence(label) for label in labels]
        gt = [ctx.tokenizer.decode(text, skipskip_special_tokens=True) for text in texts]
        total += len(imgs)
        correct += eq_cnt(pred, gt)
    return correct / total

len_train_dl = len(train_dl)
for epoch in range(1, ctx.epochs + 1):
    train_avg_loss = 0.
    pbar = tqdm(train_dl)
    pbar.set_description_str(f'Training one epoch {epoch}/{ctx.epochs}')
    for imgs, texts in pbar:
        imgs = imgs.to(ctx.device)
        texts = texts.to(ctx.device)
        optimizer.zero_grad()
        logits_per_image, logits_per_text = model(imgs, texts)
        loss = clip_loss(logits_per_image, logits_per_text)
        train_avg_loss += loss.item() / len_train_dl
        loss.backward()
        optimizer.step()
    if epoch % ctx.eval_interval == 0:
        print(f'{epoch}/{ctx.epochs}, train_avg_loss: {train_avg_loss:.3f}, test_acc: {eval_accuracy()}')
    else:
        print(f'{epoch}/{ctx.epochs}, train_avg_loss: {train_avg_loss:.3f}')

Training one epoch 1/256: 100%|██████████| 98/98 [00:32<00:00,  3.04it/s]


1/256, train_avg_loss: 6.194


Training one epoch 2/256: 100%|██████████| 98/98 [00:32<00:00,  3.02it/s]


2/256, train_avg_loss: 5.972


Training one epoch 3/256: 100%|██████████| 98/98 [00:32<00:00,  3.04it/s]


3/256, train_avg_loss: 5.846


Training one epoch 4/256: 100%|██████████| 98/98 [00:32<00:00,  3.04it/s]


4/256, train_avg_loss: 5.748


Training one epoch 5/256: 100%|██████████| 98/98 [00:32<00:00,  3.04it/s]


5/256, train_avg_loss: 5.664


Training one epoch 6/256: 100%|██████████| 98/98 [00:32<00:00,  3.04it/s]


6/256, train_avg_loss: 5.581


Training one epoch 7/256: 100%|██████████| 98/98 [00:32<00:00,  3.02it/s]


7/256, train_avg_loss: 5.524


Training one epoch 8/256: 100%|██████████| 98/98 [00:32<00:00,  3.04it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


8/256, train_avg_loss: 5.476, test_acc: 0.48132


Training one epoch 9/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


9/256, train_avg_loss: 5.423


Training one epoch 10/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


10/256, train_avg_loss: 5.369


Training one epoch 11/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


11/256, train_avg_loss: 5.334


Training one epoch 12/256: 100%|██████████| 98/98 [00:32<00:00,  3.05it/s]


12/256, train_avg_loss: 5.305


Training one epoch 13/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


13/256, train_avg_loss: 5.271


Training one epoch 14/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


14/256, train_avg_loss: 5.245


Training one epoch 15/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


15/256, train_avg_loss: 5.221


Training one epoch 16/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.53it/s]


16/256, train_avg_loss: 5.196, test_acc: 0.56344


Training one epoch 17/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


17/256, train_avg_loss: 5.175


Training one epoch 18/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


18/256, train_avg_loss: 5.146


Training one epoch 19/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


19/256, train_avg_loss: 5.134


Training one epoch 20/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


20/256, train_avg_loss: 5.104


Training one epoch 21/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


21/256, train_avg_loss: 5.091


Training one epoch 22/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


22/256, train_avg_loss: 5.076


Training one epoch 23/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


23/256, train_avg_loss: 5.055


Training one epoch 24/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.57it/s]


24/256, train_avg_loss: 5.045, test_acc: 0.6131


Training one epoch 25/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


25/256, train_avg_loss: 5.024


Training one epoch 26/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


26/256, train_avg_loss: 5.013


Training one epoch 27/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


27/256, train_avg_loss: 5.001


Training one epoch 28/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


28/256, train_avg_loss: 4.989


Training one epoch 29/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


29/256, train_avg_loss: 4.973


Training one epoch 30/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


30/256, train_avg_loss: 4.962


Training one epoch 31/256: 100%|██████████| 98/98 [00:32<00:00,  3.05it/s]


31/256, train_avg_loss: 4.953


Training one epoch 32/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


32/256, train_avg_loss: 4.941, test_acc: 0.6478


Training one epoch 33/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


33/256, train_avg_loss: 4.926


Training one epoch 34/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


34/256, train_avg_loss: 4.922


Training one epoch 35/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


35/256, train_avg_loss: 4.902


Training one epoch 36/256: 100%|██████████| 98/98 [00:32<00:00,  3.05it/s]


36/256, train_avg_loss: 4.895


Training one epoch 37/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


37/256, train_avg_loss: 4.881


Training one epoch 38/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


38/256, train_avg_loss: 4.878


Training one epoch 39/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


39/256, train_avg_loss: 4.869


Training one epoch 40/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.54it/s]


40/256, train_avg_loss: 4.854, test_acc: 0.6695


Training one epoch 41/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


41/256, train_avg_loss: 4.845


Training one epoch 42/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


42/256, train_avg_loss: 4.841


Training one epoch 43/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


43/256, train_avg_loss: 4.829


Training one epoch 44/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


44/256, train_avg_loss: 4.819


Training one epoch 45/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


45/256, train_avg_loss: 4.814


Training one epoch 46/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


46/256, train_avg_loss: 4.806


Training one epoch 47/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


47/256, train_avg_loss: 4.795


Training one epoch 48/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.57it/s]


48/256, train_avg_loss: 4.790, test_acc: 0.69254


Training one epoch 49/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


49/256, train_avg_loss: 4.780


Training one epoch 50/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


50/256, train_avg_loss: 4.765


Training one epoch 51/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


51/256, train_avg_loss: 4.772


Training one epoch 52/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


52/256, train_avg_loss: 4.761


Training one epoch 53/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


53/256, train_avg_loss: 4.749


Training one epoch 54/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


54/256, train_avg_loss: 4.742


Training one epoch 55/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


55/256, train_avg_loss: 4.740


Training one epoch 56/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


56/256, train_avg_loss: 4.723, test_acc: 0.70988


Training one epoch 57/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


57/256, train_avg_loss: 4.722


Training one epoch 58/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


58/256, train_avg_loss: 4.717


Training one epoch 59/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


59/256, train_avg_loss: 4.703


Training one epoch 60/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


60/256, train_avg_loss: 4.704


Training one epoch 61/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


61/256, train_avg_loss: 4.692


Training one epoch 62/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


62/256, train_avg_loss: 4.686


Training one epoch 63/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


63/256, train_avg_loss: 4.680


Training one epoch 64/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.52it/s]


64/256, train_avg_loss: 4.671, test_acc: 0.72536


Training one epoch 65/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


65/256, train_avg_loss: 4.666


Training one epoch 66/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


66/256, train_avg_loss: 4.660


Training one epoch 67/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


67/256, train_avg_loss: 4.656


Training one epoch 68/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


68/256, train_avg_loss: 4.649


Training one epoch 69/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


69/256, train_avg_loss: 4.644


Training one epoch 70/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


70/256, train_avg_loss: 4.637


Training one epoch 71/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


71/256, train_avg_loss: 4.632


Training one epoch 72/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


72/256, train_avg_loss: 4.628, test_acc: 0.73794


Training one epoch 73/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


73/256, train_avg_loss: 4.626


Training one epoch 74/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


74/256, train_avg_loss: 4.616


Training one epoch 75/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


75/256, train_avg_loss: 4.613


Training one epoch 76/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


76/256, train_avg_loss: 4.607


Training one epoch 77/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


77/256, train_avg_loss: 4.599


Training one epoch 78/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


78/256, train_avg_loss: 4.595


Training one epoch 79/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


79/256, train_avg_loss: 4.593


Training one epoch 80/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


80/256, train_avg_loss: 4.586, test_acc: 0.75274


Training one epoch 81/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


81/256, train_avg_loss: 4.580


Training one epoch 82/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


82/256, train_avg_loss: 4.575


Training one epoch 83/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


83/256, train_avg_loss: 4.575


Training one epoch 84/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


84/256, train_avg_loss: 4.563


Training one epoch 85/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


85/256, train_avg_loss: 4.562


Training one epoch 86/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


86/256, train_avg_loss: 4.554


Training one epoch 87/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


87/256, train_avg_loss: 4.551


Training one epoch 88/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.54it/s]


88/256, train_avg_loss: 4.542, test_acc: 0.76002


Training one epoch 89/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


89/256, train_avg_loss: 4.541


Training one epoch 90/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


90/256, train_avg_loss: 4.543


Training one epoch 91/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


91/256, train_avg_loss: 4.544


Training one epoch 92/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


92/256, train_avg_loss: 4.534


Training one epoch 93/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


93/256, train_avg_loss: 4.528


Training one epoch 94/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


94/256, train_avg_loss: 4.525


Training one epoch 95/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


95/256, train_avg_loss: 4.520


Training one epoch 96/256: 100%|██████████| 98/98 [00:35<00:00,  2.79it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.57it/s]


96/256, train_avg_loss: 4.520, test_acc: 0.76658


Training one epoch 97/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


97/256, train_avg_loss: 4.515


Training one epoch 98/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


98/256, train_avg_loss: 4.510


Training one epoch 99/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


99/256, train_avg_loss: 4.509


Training one epoch 100/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


100/256, train_avg_loss: 4.500


Training one epoch 101/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


101/256, train_avg_loss: 4.502


Training one epoch 102/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


102/256, train_avg_loss: 4.498


Training one epoch 103/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


103/256, train_avg_loss: 4.493


Training one epoch 104/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.59it/s]


104/256, train_avg_loss: 4.491, test_acc: 0.77286


Training one epoch 105/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


105/256, train_avg_loss: 4.485


Training one epoch 106/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


106/256, train_avg_loss: 4.490


Training one epoch 107/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


107/256, train_avg_loss: 4.489


Training one epoch 108/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


108/256, train_avg_loss: 4.478


Training one epoch 109/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


109/256, train_avg_loss: 4.475


Training one epoch 110/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


110/256, train_avg_loss: 4.476


Training one epoch 111/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


111/256, train_avg_loss: 4.475


Training one epoch 112/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.54it/s]


112/256, train_avg_loss: 4.475, test_acc: 0.7785


Training one epoch 113/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


113/256, train_avg_loss: 4.465


Training one epoch 114/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


114/256, train_avg_loss: 4.465


Training one epoch 115/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


115/256, train_avg_loss: 4.463


Training one epoch 116/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


116/256, train_avg_loss: 4.465


Training one epoch 117/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


117/256, train_avg_loss: 4.457


Training one epoch 118/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


118/256, train_avg_loss: 4.460


Training one epoch 119/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


119/256, train_avg_loss: 4.456


Training one epoch 120/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.60it/s]


120/256, train_avg_loss: 4.452, test_acc: 0.78102


Training one epoch 121/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


121/256, train_avg_loss: 4.443


Training one epoch 122/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


122/256, train_avg_loss: 4.454


Training one epoch 123/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


123/256, train_avg_loss: 4.445


Training one epoch 124/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


124/256, train_avg_loss: 4.449


Training one epoch 125/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


125/256, train_avg_loss: 4.441


Training one epoch 126/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


126/256, train_avg_loss: 4.441


Training one epoch 127/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


127/256, train_avg_loss: 4.436


Training one epoch 128/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.57it/s]


128/256, train_avg_loss: 4.438, test_acc: 0.78534


Training one epoch 129/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


129/256, train_avg_loss: 4.434


Training one epoch 130/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


130/256, train_avg_loss: 4.439


Training one epoch 131/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


131/256, train_avg_loss: 4.430


Training one epoch 132/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


132/256, train_avg_loss: 4.432


Training one epoch 133/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


133/256, train_avg_loss: 4.427


Training one epoch 134/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


134/256, train_avg_loss: 4.429


Training one epoch 135/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


135/256, train_avg_loss: 4.430


Training one epoch 136/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.55it/s]


136/256, train_avg_loss: 4.425, test_acc: 0.78774


Training one epoch 137/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


137/256, train_avg_loss: 4.423


Training one epoch 138/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


138/256, train_avg_loss: 4.420


Training one epoch 139/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


139/256, train_avg_loss: 4.415


Training one epoch 140/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


140/256, train_avg_loss: 4.419


Training one epoch 141/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


141/256, train_avg_loss: 4.418


Training one epoch 142/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


142/256, train_avg_loss: 4.418


Training one epoch 143/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


143/256, train_avg_loss: 4.414


Training one epoch 144/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:22<00:00,  4.41it/s]


144/256, train_avg_loss: 4.413, test_acc: 0.78868


Training one epoch 145/256: 100%|██████████| 98/98 [00:33<00:00,  2.89it/s]


145/256, train_avg_loss: 4.415


Training one epoch 146/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


146/256, train_avg_loss: 4.406


Training one epoch 147/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


147/256, train_avg_loss: 4.405


Training one epoch 148/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


148/256, train_avg_loss: 4.409


Training one epoch 149/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


149/256, train_avg_loss: 4.404


Training one epoch 150/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


150/256, train_avg_loss: 4.405


Training one epoch 151/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


151/256, train_avg_loss: 4.409


Training one epoch 152/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


152/256, train_avg_loss: 4.405, test_acc: 0.79138


Training one epoch 153/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


153/256, train_avg_loss: 4.403


Training one epoch 154/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


154/256, train_avg_loss: 4.407


Training one epoch 155/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


155/256, train_avg_loss: 4.397


Training one epoch 156/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


156/256, train_avg_loss: 4.402


Training one epoch 157/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


157/256, train_avg_loss: 4.402


Training one epoch 158/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


158/256, train_avg_loss: 4.402


Training one epoch 159/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


159/256, train_avg_loss: 4.397


Training one epoch 160/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.55it/s]


160/256, train_avg_loss: 4.394, test_acc: 0.79122


Training one epoch 161/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


161/256, train_avg_loss: 4.392


Training one epoch 162/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


162/256, train_avg_loss: 4.394


Training one epoch 163/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


163/256, train_avg_loss: 4.394


Training one epoch 164/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


164/256, train_avg_loss: 4.396


Training one epoch 165/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


165/256, train_avg_loss: 4.392


Training one epoch 166/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


166/256, train_avg_loss: 4.389


Training one epoch 167/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


167/256, train_avg_loss: 4.389


Training one epoch 168/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.58it/s]


168/256, train_avg_loss: 4.388, test_acc: 0.79262


Training one epoch 169/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


169/256, train_avg_loss: 4.385


Training one epoch 170/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


170/256, train_avg_loss: 4.385


Training one epoch 171/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


171/256, train_avg_loss: 4.388


Training one epoch 172/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


172/256, train_avg_loss: 4.386


Training one epoch 173/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


173/256, train_avg_loss: 4.384


Training one epoch 174/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


174/256, train_avg_loss: 4.387


Training one epoch 175/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


175/256, train_avg_loss: 4.384


Training one epoch 176/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.60it/s]


176/256, train_avg_loss: 4.382, test_acc: 0.79242


Training one epoch 177/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


177/256, train_avg_loss: 4.379


Training one epoch 178/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


178/256, train_avg_loss: 4.382


Training one epoch 179/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


179/256, train_avg_loss: 4.378


Training one epoch 180/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


180/256, train_avg_loss: 4.381


Training one epoch 181/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


181/256, train_avg_loss: 4.378


Training one epoch 182/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


182/256, train_avg_loss: 4.376


Training one epoch 183/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


183/256, train_avg_loss: 4.383


Training one epoch 184/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.54it/s]


184/256, train_avg_loss: 4.376, test_acc: 0.79262


Training one epoch 185/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


185/256, train_avg_loss: 4.376


Training one epoch 186/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


186/256, train_avg_loss: 4.372


Training one epoch 187/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


187/256, train_avg_loss: 4.377


Training one epoch 188/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


188/256, train_avg_loss: 4.372


Training one epoch 189/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


189/256, train_avg_loss: 4.376


Training one epoch 190/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


190/256, train_avg_loss: 4.372


Training one epoch 191/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


191/256, train_avg_loss: 4.378


Training one epoch 192/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.60it/s]


192/256, train_avg_loss: 4.371, test_acc: 0.79396


Training one epoch 193/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


193/256, train_avg_loss: 4.372


Training one epoch 194/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


194/256, train_avg_loss: 4.369


Training one epoch 195/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


195/256, train_avg_loss: 4.370


Training one epoch 196/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


196/256, train_avg_loss: 4.369


Training one epoch 197/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


197/256, train_avg_loss: 4.367


Training one epoch 198/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


198/256, train_avg_loss: 4.370


Training one epoch 199/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


199/256, train_avg_loss: 4.366


Training one epoch 200/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.59it/s]


200/256, train_avg_loss: 4.370, test_acc: 0.79508


Training one epoch 201/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


201/256, train_avg_loss: 4.365


Training one epoch 202/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


202/256, train_avg_loss: 4.368


Training one epoch 203/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


203/256, train_avg_loss: 4.367


Training one epoch 204/256: 100%|██████████| 98/98 [00:32<00:00,  3.05it/s]


204/256, train_avg_loss: 4.365


Training one epoch 205/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


205/256, train_avg_loss: 4.363


Training one epoch 206/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


206/256, train_avg_loss: 4.366


Training one epoch 207/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


207/256, train_avg_loss: 4.365


Training one epoch 208/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.54it/s]


208/256, train_avg_loss: 4.364, test_acc: 0.7964


Training one epoch 209/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


209/256, train_avg_loss: 4.364


Training one epoch 210/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


210/256, train_avg_loss: 4.363


Training one epoch 211/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


211/256, train_avg_loss: 4.362


Training one epoch 212/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


212/256, train_avg_loss: 4.363


Training one epoch 213/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


213/256, train_avg_loss: 4.359


Training one epoch 214/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


214/256, train_avg_loss: 4.359


Training one epoch 215/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


215/256, train_avg_loss: 4.359


Training one epoch 216/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.59it/s]


216/256, train_avg_loss: 4.358, test_acc: 0.79582


Training one epoch 217/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


217/256, train_avg_loss: 4.360


Training one epoch 218/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


218/256, train_avg_loss: 4.359


Training one epoch 219/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


219/256, train_avg_loss: 4.360


Training one epoch 220/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


220/256, train_avg_loss: 4.354


Training one epoch 221/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


221/256, train_avg_loss: 4.356


Training one epoch 222/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


222/256, train_avg_loss: 4.356


Training one epoch 223/256: 100%|██████████| 98/98 [00:31<00:00,  3.06it/s]


223/256, train_avg_loss: 4.353


Training one epoch 224/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.59it/s]


224/256, train_avg_loss: 4.354, test_acc: 0.7966


Training one epoch 225/256: 100%|██████████| 98/98 [00:31<00:00,  3.09it/s]


225/256, train_avg_loss: 4.357


Training one epoch 226/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


226/256, train_avg_loss: 4.353


Training one epoch 227/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


227/256, train_avg_loss: 4.354


Training one epoch 228/256: 100%|██████████| 98/98 [00:32<00:00,  3.05it/s]


228/256, train_avg_loss: 4.349


Training one epoch 229/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


229/256, train_avg_loss: 4.354


Training one epoch 230/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


230/256, train_avg_loss: 4.351


Training one epoch 231/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


231/256, train_avg_loss: 4.353


Training one epoch 232/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.55it/s]


232/256, train_avg_loss: 4.351, test_acc: 0.79698


Training one epoch 233/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


233/256, train_avg_loss: 4.352


Training one epoch 234/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


234/256, train_avg_loss: 4.351


Training one epoch 235/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


235/256, train_avg_loss: 4.350


Training one epoch 236/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


236/256, train_avg_loss: 4.352


Training one epoch 237/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


237/256, train_avg_loss: 4.350


Training one epoch 238/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


238/256, train_avg_loss: 4.348


Training one epoch 239/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


239/256, train_avg_loss: 4.350


Training one epoch 240/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.59it/s]


240/256, train_avg_loss: 4.348, test_acc: 0.79622


Training one epoch 241/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


241/256, train_avg_loss: 4.350


Training one epoch 242/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


242/256, train_avg_loss: 4.352


Training one epoch 243/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


243/256, train_avg_loss: 4.346


Training one epoch 244/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


244/256, train_avg_loss: 4.347


Training one epoch 245/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


245/256, train_avg_loss: 4.347


Training one epoch 246/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


246/256, train_avg_loss: 4.350


Training one epoch 247/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


247/256, train_avg_loss: 4.345


Training one epoch 248/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.59it/s]


248/256, train_avg_loss: 4.347, test_acc: 0.7959


Training one epoch 249/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


249/256, train_avg_loss: 4.345


Training one epoch 250/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


250/256, train_avg_loss: 4.343


Training one epoch 251/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


251/256, train_avg_loss: 4.343


Training one epoch 252/256: 100%|██████████| 98/98 [00:32<00:00,  3.06it/s]


252/256, train_avg_loss: 4.346


Training one epoch 253/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


253/256, train_avg_loss: 4.344


Training one epoch 254/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]


254/256, train_avg_loss: 4.344


Training one epoch 255/256: 100%|██████████| 98/98 [00:31<00:00,  3.07it/s]


255/256, train_avg_loss: 4.343


Training one epoch 256/256: 100%|██████████| 98/98 [00:31<00:00,  3.08it/s]
Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.53it/s]

256/256, train_avg_loss: 4.346, test_acc: 0.79732





In [264]:
model.eval()

torch.save(
    model.state_dict(), ctx.saved_w
)

eval_accuracy()

Evaling accuracy on test set: 100%|██████████| 98/98 [00:21<00:00,  4.56it/s]


0.79732