<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/lucid_perceiverio_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### perceiveriornn

#### setup

In [1]:
# https://arxiv.org/pdf/2107.14795.pdf
# https://github.com/lucidrains/perceiver-pytorch
!pip install einops
from math import pi, log
from functools import wraps
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
device = "cuda" if torch.cuda.is_available() else "cpu"


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1


#### helpers

In [2]:
# helpers
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cache_fn(f):
    cache = None
    # @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

# helper classes
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context = normed_context)
        return self.fn(x, **kwargs)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)
        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)



#### PerceiverIO

In [3]:
# @title PerceiverIO class save
class PerceiverIO(nn.Module):
    def __init__(
        self,
        *,
        depth,
        dim,
        queries_dim,
        logits_dim = None,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim),
            PreNorm(latent_dim, FeedForward(latent_dim))
        ])
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}
        for i in range(depth):
            self.layers.append(nn.ModuleList([get_latent_attn(**cache_args), get_latent_ff(**cache_args)]))
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()

    def forward(self, data, mask = None, queries = None):
        b, *_, device = *data.shape, data.device
        x = repeat(self.latents, 'n d -> b n d', b = b)
        cross_attn, cross_ff = self.cross_attend_blocks
        # cross attention only happens once for Perceiver IO
        x = cross_attn(x, context = data, mask = mask) + x
        x = cross_ff(x) + x
        # layers
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x
        if not exists(queries):
            return x
        # make sure queries contains batch dimension
        if queries.ndim == 2:
            queries = repeat(queries, 'n d -> b n d', b = b)
        # cross attend from decoder queries to latents
        latents = self.decoder_cross_attn(queries, context = x)
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)
        return self.to_logits(latents)

def preprocess(X):
    if X.dim()==1:
        X=X.unsqueeze(dim=0)
    X=X.flatten(start_dim=1, end_dim=-1) #(start_dim=1)
    X=X.unsqueeze(dim=1)
    return X

def postprocess(logits):
    if logits.dim()==3:
        logits=logits.squeeze(dim=1)
    return logits


In [4]:
# @title PerceiverIO model save
model = PerceiverIO(
    dim = 28*28,                    # dimension of sequence to be encoded
    queries_dim = 10,            # dimension of decoder queries
    logits_dim = None,            # dimension of final logits
    depth = 6,                   # depth of net
    num_latents = 128,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 128,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
).to(device)

seq = torch.randn(5, 1, 28*28, device=device)
queries = torch.zeros(1, 10, device=device)

seq = preprocess(seq) #[512,1,10]
logits = model(seq, queries = None) # 
# logits = model(seq, queries = queries) # 
logprobs = postprocess(logits) #[512, 4]
print(logprobs.shape)

# none
# in forward torch.Size([128, 128]) 5
# torch.Size([5, 128, 128])
# torch.Size([5, 128, 128])
# torch.Size([5, 128, 128])
# torch.Size([5, 128, 128])

# queries
# in forward torch.Size([128, 128]) 5
# torch.Size([5, 128, 128])
# torch.Size([5, 128, 128])
# torch.Size([5, 128, 128])
# torch.Size([5, 10])


torch.Size([5, 128, 128])


#### PerceiverIOrnn

In [5]:

class PerceiverIOrnn(nn.Module):
    def __init__(
        self,
        *,
        depth,
        dim,
        queries_dim,
        logits_dim = None,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False
    ):
        super().__init__()
        # self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        self.latents = torch.zeros(num_latents, latent_dim)
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim),
            PreNorm(latent_dim, FeedForward(latent_dim))
        ])
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}
        for i in range(depth):
            self.layers.append(nn.ModuleList([get_latent_attn(**cache_args), get_latent_ff(**cache_args)]))
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()

    def forward(self, data, mask = None, queries = None, x = None):
        b, *_, device = *data.shape, data.device
        if x == None: x = repeat(self.latents, 'n d -> b n d', b = b).to(device)
        cross_attn, cross_ff = self.cross_attend_blocks
        # cross attention only happens once for Perceiver IO
        x = cross_attn(x, context = data, mask = mask) + x
        x = cross_ff(x) + x
        # layers
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x
        if not exists(queries):
            return x
        # make sure queries contains batch dimension
        if queries.ndim == 2:
            queries = repeat(queries, 'n d -> b n d', b = b)
        # cross attend from decoder queries to latents
        latents = self.decoder_cross_attn(queries, context = x)
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)
        # return self.to_logits(latents)
        return x, self.to_logits(latents)

def preprocess(X):
    if X.dim()==1:
        X=X.unsqueeze(dim=0)
    X=X.flatten(start_dim=1, end_dim=-1) #(start_dim=1)
    X=X.unsqueeze(dim=1)
    return X

def postprocess(logits):
    if logits.dim()==3:
        logits=logits.squeeze(dim=1)
    return logits


#### wwwwwwwwwwwwwww

In [None]:

# input_size – The number of expected features in the input x
# hidden_size – The number of features in the hidden state h
# num_layers – Number of recurrent layers

# rnn = nn.LSTM(10, 20, 2) # (input_size, hidden_size, num_layers)
# input = torch.randn(5, 3, 10) # batch, input_size
# h0 = torch.randn(2, 3, 20)
# c0 = torch.randn(2, 3, 20)
# output, (hn, cn) = rnn(input, (h0, c0))

# rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
# input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
# hx = torch.randn(3, 20) # (batch, hidden_size)
# cx = torch.randn(3, 20)
# output = []
# for i in range(input.size()[0]):
#     hx, cx = rnn(input[i], (hx, cx))
#     output.append(hx)
# output = torch.stack(output, dim=0)


# seq = torch.randn(5, 1, 28,28, device=device) # batch, rgb, h, w
# # b, in_channels, *axis, device, dtype = *data.shape, data.device, data.dtype # 4 [224, 224] cpu torch.float32
# b, in_channels, h,w, device, dtype = *data.shape, data.device, data.dtype # 4 [224, 224] cpu torch.float32

# # in_shape # mario (240, 256)
# self.lstm = nn.LSTMCell(in_shape[1], 256)
# hx = torch.zeros(1, 256).to(device)
# a3c_hx1, a3c_cx1 = self.lstm(vec_st, (a3c_hx, a3c_cx)) # [1, 320], ([1, 256],[1, 256])

input_size = 28
seq = torch.randn(1, input_size, device=device)

model = PerceiverIOrnn(
    # dim = h*w,                    # dimension of sequence to be encoded
    dim = input_size,                    # dimension of sequence to be encoded
    queries_dim = 10,            # dimension of decoder queries
    logits_dim = None,            # dimension of final logits
    depth = 1,                   # depth of net
    num_latents = 128,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 128,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
).to(device)

seq = preprocess(seq) #[512,1,10]
latent = model(seq, queries = None) # 
# x = nn.Parameter(torch.randn(num_latents=512, latent_dim=512))
for x in range(5):
    # seq = torch.randn(1, 28*28, device=device)
    seq = torch.randn(1, input_size, device=device)
    seq = preprocess(seq) #[512,1,10]
    latent = model(seq, queries = None, x=latent) # 
queries = torch.zeros(1, 10, device=device)
latent, logits = model(seq, queries = queries) # 
print(logits)
print(logits.shape)
logprobs = postprocess(logits) #[512, 4]
print(logprobs.shape)



tensor([[[ 0.0194, -0.2600, -0.0958,  0.1818, -0.2351,  1.0639,  0.5648,
           0.1742, -0.0940,  0.1996]]], grad_fn=<AddBackward0>)
torch.Size([1, 1, 10])
torch.Size([1, 10])


### rnn

#### rnn setup

In [7]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# https://github.com/python-engineer/pytorch-examples/blob/master/rnn-lstm-gru/main.py

train_data = torchvision.datasets.FashionMNIST(root="data", train=True, download=True,transform=transforms.ToTensor(),)
test_data = torchvision.datasets.FashionMNIST(root="data", train=False, download=True, transform=transforms.ToTensor(),) #opt no download
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)#, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)#, shuffle=False)

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cuda" if torch.cuda.is_available() else "cpu"

input_size = 28
sequence_length = 28
hidden_size = 128
num_layers = 2
num_classes = 10



Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw



#### rnn lstm

In [16]:
# rnn lstm
# Fully connected neural network with one hidden layer
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # -> x needs to be: (batch_size, seq, input_size)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        # x: (n, 28, 28), h0: (2, n, 128)
        out, _ = self.rnn(x, h0)
        # out, _ = self.lstm(x, (h0,c0))
        # out:(batch_size, seq_length, hidden_size) (n, 28, 128)
        out = out[:, -1, :] # out: (n, 128)
        out = self.fc(out) # out: (n, 10)
        return out

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# print(model)


#### PIORNN

In [27]:

class PIORNN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(PIORNN, self).__init__()
        self.cell = PerceiverIOrnn(
            dim = input_size,                    # dimension of sequence to be encoded
            queries_dim = num_classes,            # dimension of decoder queries
            logits_dim = None,            # dimension of final logits
            depth = 1,                   # depth of net
            num_latents = 16,           # number of latents, or induced set points, or centroids. different papers giving it different names
            latent_dim = 16,            # latent dimension
            cross_heads = 1,             # number of heads for cross attention. paper said 1
            latent_heads = 4,            # number of heads for latent self attention, 8
            cross_dim_head = 8,         # number of dimensions per cross attention head
            latent_dim_head = 8,        # number of dimensions per latent self attention head
            weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
            )#.to(device)
    def postprocess(self, logits):
        if logits.dim()==3:
            logits=logits.squeeze(dim=1)
        return logits
    def forward(self, input): # input = torch.randn(64, 28, 28, device=device)
        if input.dim()==2:
            input=input.unsqueeze(0)
        if input.dim() not in [2,3]: print("erm")
        b,h,w = input.shape
        latent=None
        for i in range(h):  # dim=1
            seq = input[:, i, :]
            seq = preprocess(seq) #[b,1,num_classes]
            latent = self.cell(seq, queries = None, x=latent) # 
        queries = torch.zeros(1, 10, device=device)
        queries[-1]=1
        latent, logits = self.cell(seq, queries = queries) # 
        logprobs = postprocess(logits) #
        pred_probab = nn.Softmax(dim=1)(logprobs)
        outputs = pred_probab # [b, num_classes]
        # y_pred = pred_probab.argmax(1).float() # [b]
        return outputs

model = PIORNN(28, 10).to(device)
input = torch.randn(64, 28, 28, device=device)

output=model(input)
# print(output.shape)




#### train/load function

In [28]:
loss_fn = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(train_loader, model, loss_fn, optimizer):
    n_total_steps = len(train_loader)
    size = len(train_loader.dataset)
    model.train()
    # https://stackoverflow.com/questions/69428646/pytorch-gpu-what-am-i-forgetting-to-move-over-to-the-gpu
    for batch, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device) # origin shape: [N, 1, 28, 28] resized: [N, 28, 28]
        labels = labels.to(device)
        outputs = model(images)
        # print(outputs.shape, labels.shape)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(images)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(test_loader, model, loss_fn):
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for X, y in test_loader:
            X = X.reshape(-1, sequence_length, input_size).to(device)
            y = y.to(device)
            outputs = model(X)
            _, predicted = torch.max(outputs.data, 1)
            n_samples += y.size(0)
            n_correct += (predicted == y).sum().item()
        acc = 100.0 * n_correct / n_samples
        print(f'Accuracy of the network on the 10000 test images: {acc} %')

# def test(test_loader, model, loss_fn):
#     size = len(test_loader.dataset)
#     num_batches = len(test_loader)
#     model.eval()
#     test_loss, correct = 0, 0
#     with torch.no_grad():
#         for X, y in test_loader:
#             X, y = X.to(device), y.to(device)
#             X = torch.squeeze(X)
#             pred = model(X)
#             test_loss += loss_fn(pred, y).item()
#             correct += (pred.argmax(1) == y).type(torch.float).sum().item()
#     test_loss /= num_batches
#     correct /= size
#     print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

train(train_loader, model, loss_fn, optimizer)
test(test_loader, model, loss_fn)

loss: 2.300376  [    0/60000]
loss: 2.034185  [ 6400/60000]
loss: 2.066639  [12800/60000]
loss: 2.114356  [19200/60000]
loss: 2.025556  [25600/60000]
loss: 2.106337  [32000/60000]
loss: 2.060969  [38400/60000]
loss: 2.064992  [44800/60000]
loss: 2.056250  [51200/60000]
loss: 2.075840  [57600/60000]
Accuracy of the network on the 10000 test images: 36.13 %


In [30]:

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer)
    test(train_loader, model, loss_fn)
print("Done!")
# torch.save(model.state_dict(), "model.pth")
# print("Saved PyTorch Model State to model.pth")
# model = NeuralNetwork()
# model.load_state_dict(torch.load("model.pth"))

# og rnn 10000 test images: 86.515 %

Epoch 1
-------------------------------
loss: 2.030659  [    0/60000]
loss: 2.005738  [ 6400/60000]
loss: 2.015051  [12800/60000]
loss: 2.123432  [19200/60000]
loss: 2.025098  [25600/60000]
loss: 2.116124  [32000/60000]
loss: 1.981675  [38400/60000]
loss: 1.996964  [44800/60000]
loss: 2.005044  [51200/60000]
loss: 2.024067  [57600/60000]
Accuracy of the network on the 10000 test images: 40.645 %
Epoch 2
-------------------------------
loss: 2.019481  [    0/60000]
loss: 2.026058  [ 6400/60000]
loss: 1.998628  [12800/60000]
loss: 2.151947  [19200/60000]
loss: 2.029094  [25600/60000]
loss: 2.029098  [32000/60000]
loss: 1.982815  [38400/60000]
loss: 1.989564  [44800/60000]
loss: 2.032305  [51200/60000]
loss: 2.026143  [57600/60000]
Accuracy of the network on the 10000 test images: 41.37 %
Epoch 3
-------------------------------
loss: 2.012411  [    0/60000]
loss: 2.026163  [ 6400/60000]
loss: 2.095206  [12800/60000]
loss: 2.112392  [19200/60000]
loss: 2.046622  [25600/60000]
loss: 2.02449

In [None]:
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",]

model.eval()

import random
n=random.randint(0,1000)
print(n)
x, y = test_data[n][0], test_data[n][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')


331
Predicted: "Sneaker", Actual: "Sneaker"


In [None]:
# https://gmihaila.github.io/tutorial_notebooks/gpt2_finetune_classification/
# https://huggingface.co/docs/transformers/v4.17.0/en/model_doc/gpt2#transformers.GPT2ForSequenceClassification
dataloader=train_loader
predictions_labels = []
true_labels = []
total_loss = 0
model.train() # set model to train mode
# for batch in tqdm(dataloader, total=len(dataloader)):
# for batch, (images, labels) in enumerate(dataloader):
for batch, labels in enumerate(dataloader):
    # true_labels += batch['labels'].numpy().flatten().tolist()
    # batch = {k:v.type(torch.long).to(device) for k,v in batch.items()} # move batch to device
    model.zero_grad() # clear previously calculated gradients before backward pass.
    # print(batch.items())
    # print(batch.keys())
    # print(batch['labels'])
    # print(batch['attention_mask'])
    # print(batch['input_ids'])
    outputs = model(labels)
    # outputs = model(**batch)
    loss, logits = outputs[:2]
    total_loss += loss.item() #`.item()` returns the Python value from the tensor.
    loss.backward() # back propagate loss
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Clip norm of gradients to 1.0 to help prevent the "exploding gradients"
    optimizer.step()
    logits = logits.detach().cpu().numpy()
    predictions_labels += logits.argmax(axis=-1).flatten().tolist()
avg_epoch_loss = total_loss / len(dataloader)