In [4]:
!pip install torchmetrics einops

Collecting torchmetrics
  Using cached torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m926.4/926.4 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.9 torchmetrics-1.6.0


In [5]:
import torch
from torch import Tensor, nn, optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import datasets

from einops import rearrange

from torch.utils.data import DataLoader, Dataset, random_split

import matplotlib.pyplot as plt
import numpy as np

from tqdm.auto import tqdm

from torchmetrics import Accuracy, F1Score

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
# Training pipeline: tensor conversion, normalisation to roughly [-1, 1],
# then a random horizontal flip as data augmentation.
train_transform = T.Compose([
  T.ToTensor(),
  T.Normalize((0.5,), (0.5,)),
  T.RandomHorizontalFlip(),
])

# Evaluation pipeline: identical preprocessing but WITHOUT the random flip,
# so test metrics are deterministic and measured on un-augmented images.
test_transform = T.Compose([
  T.ToTensor(),
  T.Normalize((0.5,), (0.5,)),
])

# Kept for backward compatibility with any cell that still refers to `transform`.
transform = train_transform

train_dataset = datasets.FashionMNIST(
  root='./data',
  train=True,
  transform=train_transform,
  download=True,
)

# download=False: relies on the files already fetched by the training dataset above.
test_dataset = datasets.FashionMNIST(
  root='./data',
  train=False,
  transform=test_transform,
  download=False,
)

# Create dataloaders to iterate over the datasets.
# Training batches are shuffled and the last incomplete batch is dropped;
# the test loader keeps every sample and a fixed order.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [8]:
# Pull a single mini-batch from the training loader and display the shapes
# of the image and label tensors as a sanity check.
images, labels = next(iter(train_loader))
tuple(tensor.shape for tensor in (images, labels))

(torch.Size([8, 1, 28, 28]), torch.Size([8]))

In [9]:
!curl https://raw.githubusercontent.com/ggonzalesd/swin-transformer-classification/refs/heads/main/m_down.py --output m_down.py
!curl https://raw.githubusercontent.com/ggonzalesd/swin-transformer-classification/refs/heads/main/m_conv.py --output m_conv.py
!curl https://raw.githubusercontent.com/ggonzalesd/swin-transformer-classification/refs/heads/main/m_swing.py --output m_swing.py
!curl https://raw.githubusercontent.com/ggonzalesd/swin-transformer-classification/refs/heads/main/m_poolattention.py --output m_poolattention.py

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   762  100   762    0     0   2936      0 --:--:-- --:--:-- --:--:--  2942
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1235  100  1235    0     0   6722      0 --:--:-- --:--:-- --:--:--  6748
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  3122  100  3122    0     0  10682      0 --:--:-- --:--:-- --:--:-- 10691
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1200  100  1200    0     0   4982      0 --:--:-- --:--:-- --:--:--  4979


In [10]:
from m_down import DownBlock
from m_conv import Convolutional
from m_swing import SwinTransformer
from m_poolattention import PoolAttention

In [30]:
class Model(nn.Module):
  """Convolution + Swin-transformer classifier that emits 10 logits per sample.

  Pipeline (see ``forward``): a convolutional stem, two
  (SwinTransformer -> DownBlock) stages, attention pooling over the flattened
  spatial positions, and a final linear layer on a SiLU-activated vector.

  NOTE(review): the constructor arguments of ``Convolutional``,
  ``SwinTransformer``, ``DownBlock`` and ``PoolAttention`` come from the
  downloaded ``m_*.py`` modules; presumably 32/64/128 are channel widths and
  14/7 are feature-map sizes -- confirm against those modules.
  """

  def __init__(self):
    super().__init__()
    # Stem
    self.conv = Convolutional(1, 32)
    # Stage 1
    self.vit1 = SwinTransformer(14, 32, 4, 4)
    self.down1 = DownBlock(32, 64)
    # Stage 2
    self.vit2 = SwinTransformer(7, 64, 4, 4)
    self.down2 = DownBlock(64, 128)
    # Head
    self.pool = PoolAttention(128)
    self.output = nn.Linear(128, 10)

  def forward(self, x:Tensor) -> Tensor:
    """Map a batch of images to class logits.

    The transformer blocks are fed with axis 1 moved to the last position and
    the down blocks with it moved back, so the two layouts alternate per stage.
    """
    feats = self.conv(x)

    stages = (
      (self.vit1, self.down1),
      (self.vit2, self.down2),
    )
    for swin, down in stages:
      attended = swin(feats.movedim(1, -1))   # axis 1 -> last for the transformer
      feats = down(attended.movedim(-1, 1))   # last -> axis 1 for the down block

    # Flatten the spatial grid into a sequence of (h*w) tokens with c features each.
    tokens = rearrange(feats, 'b c h w -> b (h w) c')

    pooled = self.pool(tokens)
    return self.output(F.silu(pooled))


# Build the network and move its parameters to the selected device.
model = Model()
model = model.to(DEVICE)

In [31]:
# AdamW over all model parameters with a small learning rate and near-zero weight decay.
optimizer = optim.AdamW(
  model.parameters(),
  lr=1e-4,
  weight_decay=1e-8,
)

# Cross-entropy on raw logits vs. integer class labels.
loss_fn = nn.CrossEntropyLoss()

# Support-weighted multiclass metrics over the 10 classes, kept on the same
# device as the model so they can be updated without host transfers.
accuracy_metric = Accuracy(task='multiclass', num_classes=10, average='weighted').to(DEVICE)
f1_metric = F1Score(task='multiclass', num_classes=10, average='weighted').to(DEVICE)

In [32]:
def train(model:nn.Module, epoch:int, num_epochs:int=10):
  """Run one optimisation pass over the global ``train_loader``.

  Uses the notebook-level ``optimizer``, ``loss_fn``, ``accuracy_metric``,
  ``f1_metric`` and ``DEVICE``. Both metrics are reset at the start, so their
  final values cover this epoch only.

  Args:
    model: network to train; it is switched to training mode.
    epoch: 1-based epoch number, used only in the progress-bar label.
    num_epochs: total number of epochs shown in the progress-bar label.

  Returns:
    dict with the sample-weighted mean ``loss``, plus ``accuracy`` and ``f1``
    for the epoch (previously nothing was returned).
  """
  model.train()
  total_loss = 0.0
  total = 0
  accuracy_metric.reset()
  f1_metric.reset()

  # The context manager closes the bar even if the loop raises or is interrupted.
  with tqdm(total=len(train_loader), desc=f'Train ({epoch}/{num_epochs})', leave=True, colour='blue') as bar:
    for images, labels in train_loader:
      images = images.to(DEVICE)
      labels = labels.to(DEVICE)

      logits = model(images)
      loss:Tensor = loss_fn(logits, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Read the scalar once; each .item() call forces a device sync.
      batch_loss = loss.item()
      batch_size = images.shape[0]

      total += batch_size
      # Weight the batch loss by batch size so the epoch loss is a per-sample mean.
      total_loss += batch_loss * batch_size
      preds = torch.argmax(logits, dim=1)
      accuracy_metric.update(preds, labels)
      f1_metric.update(preds, labels)

      bar.set_postfix(loss=batch_loss)
      bar.update(1)

    # Guard against an empty loader instead of dividing by zero.
    if total:
      total_loss /= total
    total_accuracy = accuracy_metric.compute().item()
    total_f1 = f1_metric.compute().item()

    # set_postfix refreshes the display itself; an extra bar.update() here
    # would push the counter past `total`.
    bar.set_postfix(loss=total_loss, accuracy=total_accuracy, f1=total_f1)

  return {'loss': total_loss, 'accuracy': total_accuracy, 'f1': total_f1}

In [33]:
def test(model:nn.Module, epoch:int, num_epochs:int=10):
  """Evaluate ``model`` on the global ``test_loader`` without gradient tracking.

  Uses the notebook-level ``loss_fn``, ``accuracy_metric``, ``f1_metric`` and
  ``DEVICE``. Both metrics are reset at the start, so their final values cover
  this evaluation pass only.

  Args:
    model: network to evaluate; it is switched to eval mode.
    epoch: 1-based epoch number, used only in the progress-bar label.
    num_epochs: total number of epochs shown in the progress-bar label.

  Returns:
    dict with the sample-weighted mean ``loss``, plus ``accuracy`` and ``f1``
    for the pass (previously nothing was returned).
  """
  model.eval()
  total_loss = 0.0
  total = 0
  accuracy_metric.reset()
  f1_metric.reset()

  # The context manager closes the bar even if the loop raises or is interrupted.
  with tqdm(total=len(test_loader), desc=f'Test ({epoch}/{num_epochs})', leave=True, colour='yellow') as bar:
    with torch.inference_mode():
      for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(images)
        loss:Tensor = loss_fn(logits, labels)

        # Read the scalar once; each .item() call forces a device sync.
        batch_loss = loss.item()
        batch_size = images.shape[0]

        total += batch_size
        # Weight the batch loss by batch size so the final loss is a per-sample
        # mean even though the last test batch may be smaller.
        total_loss += batch_loss * batch_size
        preds = torch.argmax(logits, dim=1)
        accuracy_metric.update(preds, labels)
        f1_metric.update(preds, labels)

        bar.set_postfix(loss=batch_loss)
        bar.update(1)

    # Guard against an empty loader instead of dividing by zero.
    if total:
      total_loss /= total
    total_accuracy = accuracy_metric.compute().item()
    total_f1 = f1_metric.compute().item()

    # set_postfix refreshes the display itself; an extra bar.update() here
    # would push the counter past `total`.
    bar.set_postfix(loss=total_loss, accuracy=total_accuracy, f1=total_f1)

  return {'loss': total_loss, 'accuracy': total_accuracy, 'f1': total_f1}

In [34]:
# Ten epochs: one training pass followed by one evaluation pass each.
# The helpers expect a 1-based epoch number for their progress-bar labels.
for epoch in range(10):
  epoch_number = epoch + 1
  train(model, epoch_number)
  test(model, epoch_number)


Train (1/10):   0%|          | 0/468 [00:00<?, ?it/s]

Test (1/10):   0%|          | 0/79 [00:00<?, ?it/s]

Train (2/10):   0%|          | 0/468 [00:00<?, ?it/s]

Test (2/10):   0%|          | 0/79 [00:00<?, ?it/s]

Train (3/10):   0%|          | 0/468 [00:00<?, ?it/s]

Test (3/10):   0%|          | 0/79 [00:00<?, ?it/s]

Train (4/10):   0%|          | 0/468 [00:00<?, ?it/s]

Test (4/10):   0%|          | 0/79 [00:00<?, ?it/s]

Train (5/10):   0%|          | 0/468 [00:00<?, ?it/s]

Test (5/10):   0%|          | 0/79 [00:00<?, ?it/s]

Train (6/10):   0%|          | 0/468 [00:00<?, ?it/s]

KeyboardInterrupt: 