In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import itertools
from torchinfo import summary

In [3]:
class Residual(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.residual = nn.Sequential(*layers)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.gamma * self.residual(x)

class GlobalAvgPool(nn.Module):
    def forward(self, x):
        return x.mean(dim=-2)

In [4]:
class ShiftedWindowAttention(nn.Module):
    def __init__(self, dim, head_dim, shape, window_size, shift_size=0):
        super().__init__()
        self.heads = dim // head_dim
        self.head_dim = head_dim
        self.scale = head_dim**-0.5

        self.shape = shape
        self.window_size = window_size
        self.shift_size = shift_size

        self.to_qkv = nn.Linear(dim, dim * 3)
        self.unifyheads = nn.Linear(dim, dim)

        self.pos_enc = nn.Parameter(torch.Tensor(self.heads, (2 * window_size - 1)**2))
        self.register_buffer("relative_indices", self.get_indices(window_size))

        if shift_size > 0:
            self.register_buffer("mask", self.generate_mask(shape, window_size, shift_size))


    def forward(self, x):
        shift_size, window_size = self.shift_size, self.window_size

        x = self.to_windows(x, self.shape, window_size, shift_size) # partition into windows

        # self attention
        qkv = self.to_qkv(x).unflatten(-1, (3, self.heads, self.head_dim)).transpose(-2, 1)
        queries, keys, values = qkv.unbind(dim=2)

        att = queries @ keys.transpose(-2, -1)

        att = att * self.scale + self.get_rel_pos_enc(window_size) # add relative positon encoding

        # masking
        if shift_size > 0:
            att = self.mask_attention(att)

        att = F.softmax(att, dim=-1)

        x = att @ values
        x = x.transpose(1, 2).contiguous().flatten(-2, -1) # move head back
        x = self.unifyheads(x)

        x = self.from_windows(x, self.shape, window_size, shift_size) # undo partitioning into windows
        return x


    def to_windows(self, x, shape, window_size, shift_size):
        x = x.unflatten(1, shape)
        if shift_size > 0:
            x = x.roll((-shift_size, -shift_size), dims=(1, 2))
        x = self.split_windows(x, window_size)
        return x


    def from_windows(self, x, shape, window_size, shift_size):
        x = self.merge_windows(x, shape, window_size)
        if shift_size > 0:
            x = x.roll((shift_size, shift_size), dims=(1, 2))
        x = x.flatten(1, 2)
        return x


    def mask_attention(self, att):
        num_win = self.mask.size(1)
        att = att.unflatten(0, (att.size(0) // num_win, num_win))
        att = att.masked_fill(self.mask, float('-inf'))
        att = att.flatten(0, 1)
        return att


    def get_rel_pos_enc(self, window_size):
        indices = self.relative_indices.expand(self.heads, -1)
        rel_pos_enc = self.pos_enc.gather(-1, indices)
        rel_pos_enc = rel_pos_enc.unflatten(-1, (window_size**2, window_size**2))
        return rel_pos_enc


    # For explanation of mask regions see Figure 4 in the article
    @staticmethod
    def generate_mask(shape, window_size, shift_size):
        region_mask = torch.zeros(1, *shape, 1)
        slices = [slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)]

        region_num = 0
        for i in slices:
            for j in slices:
                region_mask[:, i, j, :] = region_num
                region_num += 1

        mask_windows = ShiftedWindowAttention.split_windows(region_mask, window_size).squeeze(-1)
        diff_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        mask = diff_mask != 0
        mask = mask.unsqueeze(1).unsqueeze(0) # add heads and batch dimension
        return mask


    @staticmethod
    def split_windows(x, window_size):
        n_h, n_w = x.size(1) // window_size, x.size(2) // window_size
        x = x.unflatten(1, (n_h, window_size)).unflatten(-2, (n_w, window_size)) # split into windows
        x = x.transpose(2, 3).flatten(0, 2) # merge batch and window numbers
        x = x.flatten(-3, -2)
        return x


    @staticmethod
    def merge_windows(x, shape, window_size):
        n_h, n_w = shape[0] // window_size, shape[1] // window_size
        b = x.size(0) // (n_h * n_w)
        x = x.unflatten(1, (window_size, window_size))
        x = x.unflatten(0, (b, n_h, n_w)).transpose(2, 3) # separate batch and window numbers
        x = x.flatten(1, 2).flatten(-3, -2) # merge windows
        return x


    @staticmethod
    def get_indices(window_size):
        x = torch.arange(window_size, dtype=torch.long)

        y1, x1, y2, x2 = torch.meshgrid(x, x, x, x, indexing='ij')
        indices = (y1 - y2 + window_size - 1) * (2 * window_size - 1) + x1 - x2 + window_size - 1
        indices = indices.flatten()

        return indices

In [5]:
class FeedForward(nn.Sequential):
    def __init__(self, dim, mult=4):
        hidden_dim = dim * mult
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

In [6]:
class TransformerBlock(nn.Sequential):
    def __init__(self, dim, head_dim, shape, window_size, shift_size=0, p_drop=0.):
        super().__init__(
            Residual(
                nn.LayerNorm(dim),
                ShiftedWindowAttention(dim, head_dim, shape, window_size, shift_size),
                nn.Dropout(p_drop)
            ),
            Residual(
                nn.LayerNorm(dim),
                FeedForward(dim),
                nn.Dropout(p_drop)
            )
        )

In [7]:
class PatchMerging(nn.Module):
    def __init__(self, in_dim, out_dim, shape):
        super().__init__()
        self.shape = shape
        self.norm = nn.LayerNorm(4 * in_dim)
        self.reduction = nn.Linear(4 * in_dim, out_dim, bias=False)

    def forward(self, x):
        x = x.unflatten(1, self.shape).movedim(-1, 1)
        x = F.unfold(x, kernel_size=2, stride=2).movedim(1, -1)

        x = self.norm(x)
        x = self.reduction(x)
        return x

In [8]:
class Stage(nn.Sequential):
    def __init__(self, num_blocks, in_dim, out_dim, head_dim, shape, window_size, p_drop=0.):
        if out_dim != in_dim:
            layers = [PatchMerging(in_dim, out_dim, shape)]
            shape = (shape[0] // 2, shape[1] // 2)
        else:
            layers = []

        shift_size = window_size // 2
        layers += [TransformerBlock(out_dim, head_dim, shape, window_size, 0 if (num % 2 == 0) else shift_size,
                                    p_drop) for num in range(num_blocks)]

        super().__init__(*layers)

In [9]:
class StageStack(nn.Sequential):
    def __init__(self, num_blocks_list, dims, head_dim, shape, window_size, p_drop=0.):
        layers = []
        in_dim = dims[0]
        for num, out_dim in zip(num_blocks_list, dims[1:]):
            layers.append(Stage(num, in_dim, out_dim, head_dim, shape, window_size, p_drop))
            if in_dim != out_dim:
                shape = (shape[0] // 2, shape[1] // 2)
                in_dim = out_dim

        super().__init__(*layers)

In [10]:
class ToPatches(nn.Module):
    def __init__(self, in_channels, dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        patch_dim = in_channels * patch_size**2
        self.proj = nn.Linear(patch_dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).movedim(1, -1)
        x = self.proj(x)
        x = self.norm(x)
        return x

In [11]:
class AddPositionEmbedding(nn.Module):
    def __init__(self, dim, num_patches):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.Tensor(num_patches, dim))

    def forward(self, x):
        return x + self.pos_embedding

In [12]:
class ToEmbedding(nn.Sequential):
    def __init__(self, in_channels, dim, patch_size, num_patches, p_drop=0.):
        super().__init__(
            ToPatches(in_channels, dim, patch_size),
            AddPositionEmbedding(dim, num_patches),
            nn.Dropout(p_drop)
        )

In [13]:
class Head(nn.Sequential):
    def __init__(self, dim, classes, p_drop=0.):
        super().__init__(
            nn.LayerNorm(dim),
            nn.GELU(),
            GlobalAvgPool(),
            nn.Dropout(p_drop),
            nn.Linear(dim, classes)
        )


In [14]:
class SwinTransformer(nn.Sequential):
    def __init__(self, classes, image_size, num_blocks_list, dims, head_dim, patch_size, window_size,
                 in_channels=3, emb_p_drop=0., trans_p_drop=0., head_p_drop=0.):
        reduced_size = image_size // patch_size
        shape = (reduced_size, reduced_size)
        num_patches = shape[0] * shape[1]

        super().__init__(
            ToEmbedding(in_channels, dims[0], patch_size, num_patches, emb_p_drop),
            StageStack(num_blocks_list, dims, head_dim, shape, window_size, trans_p_drop),
            Head(dims[-1], classes, head_p_drop)
        )
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.)
                nn.init.zeros_(m.bias)
            elif isinstance(m, AddPositionEmbedding):
                nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02)
            elif isinstance(m, ShiftedWindowAttention):
                nn.init.normal_(m.pos_enc, mean=0.0, std=0.02)
            elif isinstance(m, Residual):
                nn.init.zeros_(m.gamma)

    def separate_parameters(self):
        parameters_decay = set()
        parameters_no_decay = set()
        modules_weight_decay = (nn.Linear, )
        modules_no_weight_decay = (nn.LayerNorm,)

        for m_name, m in self.named_modules():
            for param_name, param in m.named_parameters():
                full_param_name = f"{m_name}.{param_name}" if m_name else param_name

                if isinstance(m, modules_no_weight_decay):
                    parameters_no_decay.add(full_param_name)
                elif param_name.endswith("bias"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, Residual) and param_name.endswith("gamma"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, AddPositionEmbedding) and param_name.endswith("pos_embedding"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, ShiftedWindowAttention) and param_name.endswith("pos_enc"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, modules_weight_decay):
                    parameters_decay.add(full_param_name)

        # sanity check
        assert len(parameters_decay & parameters_no_decay) == 0
        assert len(parameters_decay) + len(parameters_no_decay) == len(list(self.parameters()))

        return parameters_decay, parameters_no_decay

In [15]:
model = SwinTransformer(37, 224,
                        num_blocks_list=[4, 4], dims=[128, 128, 256],
                        head_dim=32, patch_size=2, window_size=4,
                        emb_p_drop=0., trans_p_drop=0., head_p_drop=0.3)
print(model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

batch_size = 32
learning_rate = 0.001
num_epochs = 10

SwinTransformer(
  (0): ToEmbedding(
    (0): ToPatches(
      (proj): Linear(in_features=12, out_features=128, bias=True)
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): AddPositionEmbedding()
    (2): Dropout(p=0.0, inplace=False)
  )
  (1): StageStack(
    (0): Stage(
      (0): TransformerBlock(
        (0): Residual(
          (residual): Sequential(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (1): ShiftedWindowAttention(
              (to_qkv): Linear(in_features=128, out_features=384, bias=True)
              (unifyheads): Linear(in_features=128, out_features=128, bias=True)
            )
            (2): Dropout(p=0.0, inplace=False)
          )
        )
        (1): Residual(
          (residual): Sequential(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (1): FeedForward(
              (0): Linear(in_features=128, out_features=512, bias=True)
              (1): GELU(

In [16]:
print(summary(model, input_size=(32, 3, 224, 224)))

Layer (type:depth-idx)                                       Output Shape              Param #
SwinTransformer                                              [32, 37]                  --
├─ToEmbedding: 1-1                                           [32, 12544, 128]          --
│    └─ToPatches: 2-1                                        [32, 12544, 128]          --
│    │    └─Linear: 3-1                                      [32, 12544, 128]          1,664
│    │    └─LayerNorm: 3-2                                   [32, 12544, 128]          256
│    └─AddPositionEmbedding: 2-2                             [32, 12544, 128]          1,605,632
│    └─Dropout: 2-3                                          [32, 12544, 128]          --
├─StageStack: 1-2                                            [32, 3136, 256]           --
│    └─Stage: 2-4                                            [32, 12544, 128]          --
│    │    └─TransformerBlock: 3-3                            [32, 12544, 128]       

In [17]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [18]:
trainval_data = datasets.OxfordIIITPet(root="data", split="trainval", target_types="category", download=True, transform=transform)
test_data = datasets.OxfordIIITPet(root="data", split="test", target_types="category", download=True, transform=transform)
combined_data = ConcatDataset([trainval_data, test_data])

train_size = int(0.7 * len(combined_data))
val_size = int(0.15 * len(combined_data))
test_size = len(combined_data) - train_size - val_size
train_data, val_data, test_data = random_split(combined_data, [train_size, val_size, test_size])

Downloading https://thor.robots.ox.ac.uk/pets/images.tar.gz to data/oxford-iiit-pet/images.tar.gz


100%|██████████| 792M/792M [00:38<00:00, 20.8MB/s]


Extracting data/oxford-iiit-pet/images.tar.gz to data/oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/pets/annotations.tar.gz to data/oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19.2M/19.2M [00:01<00:00, 10.0MB/s]


Extracting data/oxford-iiit-pet/annotations.tar.gz to data/oxford-iiit-pet


In [19]:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print(f"Train set size: {len(train_data)}")
print(f"Validation set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")

Train set size: 5144
Validation set size: 1102
Test set size: 1103


In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [21]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    accuracy = 100 * correct / total
    print(f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {accuracy:.2f}%")

In [22]:
def evaluate(model, data_loader, criterion, device, phase="Validation"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(data_loader)
    accuracy = 100 * correct / total
    print(f"{phase} Loss: {epoch_loss:.4f}, {phase} Accuracy: {accuracy:.2f}%")

In [23]:
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    train(model, train_loader, criterion, optimizer, device)
    evaluate(model, val_loader, criterion, device, phase="Validation")


Epoch 1/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.65it/s]


Train Loss: 3.6775, Train Accuracy: 3.97%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.32it/s]


Validation Loss: 3.5503, Validation Accuracy: 5.63%

Epoch 2/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.5836, Train Accuracy: 5.09%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.35it/s]


Validation Loss: 3.5018, Validation Accuracy: 6.81%

Epoch 3/10


Training: 100%|██████████| 161/161 [01:36<00:00,  1.66it/s]


Train Loss: 3.5457, Train Accuracy: 5.58%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.34it/s]


Validation Loss: 3.4829, Validation Accuracy: 6.53%

Epoch 4/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.5003, Train Accuracy: 6.43%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.37it/s]


Validation Loss: 3.4824, Validation Accuracy: 7.08%

Epoch 5/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.4529, Train Accuracy: 7.80%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.34it/s]


Validation Loss: 3.4074, Validation Accuracy: 8.53%

Epoch 6/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.65it/s]


Train Loss: 3.3886, Train Accuracy: 8.90%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.36it/s]


Validation Loss: 3.3070, Validation Accuracy: 11.89%

Epoch 7/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.3045, Train Accuracy: 10.83%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.32it/s]


Validation Loss: 3.2842, Validation Accuracy: 10.44%

Epoch 8/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.2386, Train Accuracy: 11.78%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.33it/s]


Validation Loss: 3.2107, Validation Accuracy: 12.43%

Epoch 9/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.1806, Train Accuracy: 12.71%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.36it/s]


Validation Loss: 3.1727, Validation Accuracy: 12.70%

Epoch 10/10


Training: 100%|██████████| 161/161 [01:37<00:00,  1.66it/s]


Train Loss: 3.1565, Train Accuracy: 12.81%


Validation: 100%|██████████| 35/35 [00:10<00:00,  3.33it/s]

Validation Loss: 3.1546, Validation Accuracy: 12.79%





In [24]:
print("\nFinal Test Evaluation")
evaluate(model, test_loader, criterion, device, phase="Test")


Final Test Evaluation


Test: 100%|██████████| 35/35 [00:10<00:00,  3.27it/s]

Test Loss: 3.1781, Test Accuracy: 15.32%





# timm SWIN

In [25]:
timm_swin = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=37)
print(summary(timm_swin, input_size=(32, 3, 224, 224)))

Layer (type:depth-idx)                             Output Shape              Param #
SwinTransformer                                    [32, 37]                  --
├─PatchEmbed: 1-1                                  [32, 56, 56, 128]         --
│    └─Conv2d: 2-1                                 [32, 128, 56, 56]         6,272
│    └─LayerNorm: 2-2                              [32, 56, 56, 128]         256
├─Sequential: 1-2                                  [32, 7, 7, 1024]          --
│    └─SwinTransformerStage: 2-3                   [32, 56, 56, 128]         --
│    │    └─Identity: 3-1                          [32, 56, 56, 128]         --
│    │    └─Sequential: 3-2                        [32, 56, 56, 128]         397,896
│    └─SwinTransformerStage: 2-4                   [32, 28, 28, 256]         --
│    │    └─PatchMerging: 3-3                      [32, 28, 28, 256]         132,096
│    │    └─Sequential: 3-4                        [32, 28, 28, 256]         1,582,224
│    └─SwinTra

In [26]:
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    train(timm_swin, train_loader, criterion, optimizer, device)
    evaluate(timm_swin, val_loader, criterion, device, phase="Validation")


Epoch 1/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s]


Train Loss: 3.7397, Train Accuracy: 2.62%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.33it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 2/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s]


Train Loss: 3.7452, Train Accuracy: 2.31%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.35it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 3/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s]


Train Loss: 3.7394, Train Accuracy: 2.59%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.29it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 4/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s]


Train Loss: 3.7425, Train Accuracy: 2.33%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.36it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 5/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s]


Train Loss: 3.7431, Train Accuracy: 2.29%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.37it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 6/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s]


Train Loss: 3.7440, Train Accuracy: 2.45%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.36it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 7/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s]


Train Loss: 3.7459, Train Accuracy: 2.33%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.24it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 8/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s]


Train Loss: 3.7381, Train Accuracy: 2.45%


Validation: 100%|██████████| 35/35 [00:08<00:00,  4.34it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 9/10


Training: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s]


Train Loss: 3.7387, Train Accuracy: 2.31%


Validation: 100%|██████████| 35/35 [00:07<00:00,  4.39it/s]


Validation Loss: 3.6961, Validation Accuracy: 3.54%

Epoch 10/10


Training: 100%|██████████| 161/161 [01:05<00:00,  2.45it/s]


Train Loss: 3.7445, Train Accuracy: 2.53%


Validation: 100%|██████████| 35/35 [00:07<00:00,  4.39it/s]

Validation Loss: 3.6961, Validation Accuracy: 3.54%





In [27]:
print("\nFinal Test Evaluation")
evaluate(timm_swin, test_loader, criterion, device, phase="Test")


Final Test Evaluation


Test: 100%|██████████| 35/35 [00:08<00:00,  4.21it/s]

Test Loss: 3.7442, Test Accuracy: 3.26%



