In [1]:
# Install the FLOPs/MACs counter into the kernel's own environment, pinned to
# the version this notebook was last run with so a re-run resolves the same release.
%pip install -q ptflops==0.7.4

Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->ptflops)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0->ptflops)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0->ptflops)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=2.0->ptflops)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=2.0->ptflops)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=2.0->ptflops)
  Downloading nvidia_cusp

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd


import torch
import time
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class DynamicPosBias(nn.Module):
    """Dynamic relative position bias: an MLP from 2-D offsets to per-head biases.

    Adapted from the CrossFormer implementation:
    https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py

    Args:
        dim (int): Base width; every hidden layer uses ``dim // 4`` features.
        num_heads (int): Size of the output vector produced for each offset.
        residual (bool): If True, the two hidden stages are added to their
            input (skip connections); if False, all stages are chained directly.
    """

    def __init__(self, dim, num_heads, residual):
        super().__init__()
        self.residual = residual
        self.num_heads = num_heads
        self.pos_dim = dim // 4
        width = self.pos_dim

        # Creation order is kept (proj, pos1, pos2, pos3) so seeded
        # initialisation stays reproducible.
        self.pos_proj = nn.Linear(2, width)
        self.pos1 = self._norm_act_linear(width, width)
        self.pos2 = self._norm_act_linear(width, width)
        self.pos3 = self._norm_act_linear(width, self.num_heads)

    @staticmethod
    def _norm_act_linear(in_dim, out_dim):
        """Build one LayerNorm -> ReLU -> Linear stage."""
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, biases):
        """Map offsets of shape (..., 2) to biases of shape (..., num_heads)."""
        embedded = self.pos_proj(biases)
        if not self.residual:
            return self.pos3(self.pos2(self.pos1(embedded)))
        embedded = embedded + self.pos1(embedded)
        embedded = embedded + self.pos2(embedded)
        return self.pos3(embedded)
        
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features (int): Width of the input features.
        hidden_features (int, optional): Width of the hidden layer; falls back
            to ``in_features`` when not given.
        out_features (int, optional): Width of the output; falls back to
            ``in_features`` when not given.
        act_layer (callable): Zero-argument factory for the activation module.
        drop (float): Dropout probability applied after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # A single dropout module is shared by both positions in forward().
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class DFE(nn.Module):
    """Dual Feature Extraction: a conv branch gated by a 1x1 projection branch.

    Args:
        in_features (int): Number of input channels.
        out_features (int): Number of output channels.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features

        # The conv branch works on a bottleneck of in_features // 5 channels.
        bottleneck = in_features // 5
        self.conv = nn.Sequential(
            nn.Conv2d(in_features, bottleneck, 1, 1, 0),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(bottleneck, bottleneck, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(bottleneck, out_features, 1, 1, 0),
        )
        self.linear = nn.Conv2d(in_features, out_features, 1, 1, 0)

    def forward(self, x, x_size):
        """Run both branches on tokens ``x`` and multiply them element-wise.

        Args:
            x: Token tensor of shape (B, L, C) with ``L == H * W``.
            x_size: ``(H, W)`` used to fold the tokens back into a 2-D map.

        Returns:
            Tensor of shape (B, L, out_features).
        """
        batch, _, channels = x.shape
        height, width = x_size
        feature_map = x.permute(0, 2, 1).contiguous().view(batch, channels, height, width)
        gated = self.conv(feature_map) * self.linear(feature_map)
        tokens = gated.view(batch, -1, height * width)
        return tokens.permute(0, 2, 1).contiguous()

class ConvLayer(nn.Module):
    """Optional dropout, a square 2-D convolution, then optional norm and activation.

    Args:
        in_channels (int): Input channels of the convolution.
        out_channels (int): Output channels of the convolution.
        kernel_size, stride, padding, dilation (int): Applied identically to
            both spatial dimensions.
        groups (int): Convolution groups.
        bias (bool): Whether the convolution has a bias term.
        dropout (float): ``Dropout2d`` probability applied to the input; no
            dropout module is created unless this is greater than 0.
        norm (callable | None): Factory called as ``norm(num_features=out_channels)``;
            a falsy value disables normalisation.
        act_func (callable | None): Zero-argument activation factory; a falsy
            value disables the activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, dropout=0, norm=nn.BatchNorm2d, act_func=nn.ReLU):
        super().__init__()
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout, inplace=False)
        else:
            self.dropout = None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size,) * 2,
            stride=(stride,) * 2,
            padding=(padding,) * 2,
            dilation=(dilation,) * 2,
            groups=groups,
            bias=bias,
        )
        self.norm = norm(num_features=out_channels) if norm else None
        self.act = act_func() if act_func else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout (if any), convolution, norm (if any) and activation (if any)."""
        out = x if self.dropout is None else self.dropout(x)
        out = self.conv(out)
        for stage in (self.norm, self.act):
            if stage:
                out = stage(out)
        return out


class LinearAttention(nn.Module):
    """Spatial and channel self-correlation over a fixed token grid.

    Query and value features come from a single ``DFE`` projection. The
    spatial branch (S-SC) correlates every token with values pooled down to a
    ``base_win_size`` grid and adds a dynamic relative position bias; the
    channel branch (C-SC) correlates channels with each other. Both results
    are concatenated and projected back to ``dim`` channels.

    Args:
        dim (int): Number of input/output channels. The q/v split in
            ``forward`` requires it to be divisible by ``2 * num_heads``.
        input_resolution (tuple[int, int]): Token grid ``(H, W)`` the module
            is built for; it sizes the spatial pooling layer and position bias.
        num_heads (int): Number of heads used by the spatial branch.
        base_win_size (tuple[int, int]): Grid the values are pooled to. Each
            side has to divide the matching side of ``input_resolution``.
        qkv_bias (bool): Accepted for interface compatibility; not used here.
        value_drop (float): Dropout probability applied to the values.
        proj_drop (float): Dropout probability applied after the output projection.
    """

    def __init__(self, dim, input_resolution, num_heads, base_win_size=(2,2), qkv_bias=True, value_drop = 0., proj_drop= 0., **kwargs):

        super().__init__()
        # parameters
        self.dim = dim
        self.window_size = input_resolution
        self.num_heads = num_heads
        self.base_win_size = base_win_size

        # feature projection
        self.qv = DFE(dim, dim)
        self.proj = nn.Linear(dim, dim)

        # dropout
        self.value_drop = nn.Dropout(value_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        # normalization factor and spatial linear layer for S-SC
        head_dim = dim // (2*num_heads)
        self.scale = head_dim
        self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)

        # dynamic relative position bias
        self.H_sp, self.W_sp = self.window_size
        self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)

    def spatial_linear_projection(self, x):
        """Pool ``x`` from the full grid down to ``base_win_size`` tokens.

        ``x`` is (B, heads, H*W, C). The grid is cut into ``map_H * map_W``
        regions and the tokens of each region are reduced to one by the
        learned ``spatial_linear`` layer, giving (B, heads, map_H*map_W, C).
        """
        B, num_h, _, C = x.shape
        H, W = self.window_size
        map_H, map_W = self.base_win_size

        x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
        x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
        return x

    def spatial_self_correlation(self, q, v):
        """Spatial branch (S-SC): correlate each token with the pooled values.

        ``q`` and ``v`` are (B, heads, L, C_head); the result is
        (B, L, heads * C_head).
        """
        B, _, L, _ = q.shape

        # spatial projection
        v = self.spatial_linear_projection(v)

        # compute correlation map: (B, heads, L, map_H*map_W)
        corr_map = (q @ v.transpose(-2,-1)) / self.scale

        # add relative position bias
        # generate mother-set: every offset in [1-H, H-1] x [1-W, W-1].
        # indexing='ij' is the layout this code relied on implicitly; naming it
        # avoids the torch.meshgrid warning about the missing argument.
        position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
        position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
        biases = torch.stack(torch.meshgrid(position_bias_h, position_bias_w, indexing='ij'))
        rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
        pos = self.pos(rpe_biases)

        # select position bias: index of the offset between every token pair
        coords_h = torch.arange(self.H_sp, device=v.device)
        coords_w = torch.arange(self.W_sp, device=v.device)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.H_sp - 1
        relative_coords[:, :, 1] += self.W_sp - 1
        relative_coords[:, :, 0] *= 2 * self.W_sp - 1
        relative_position_index = relative_coords.sum(-1)
        # (H*W, map_H, H/map_H, map_W, W/map_W, heads)
        relative_position_bias = pos[relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1)
        # average the bias over the tokens of each pooled region -> (H*W, map_H*map_W, heads)
        relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
            self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        corr_map = corr_map + relative_position_bias.unsqueeze(0)

        # transformation
        v_drop = self.value_drop(v)
        x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)

        return x

    def channel_self_correlation(self, q, v):
        """Channel branch (C-SC): correlate channels across all tokens.

        ``q`` and ``v`` are (B, heads, L, C_head); heads are merged first, and
        the result is (B, L, heads * C_head).
        """
        B, num_head, L, C = q.shape

        # apply single head strategy
        q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
        v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)

        # compute correlation map: (B, heads*C, heads*C), normalised by the token count
        corr_map = (q.transpose(-2,-1) @ v) / L

        # transformation
        v_drop = self.value_drop(v)
        x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)

        return x

    def forward(self, x):
        """
        Args:
            x: input tokens with shape (B, N, C), where N is the number of
               grid positions (``H * W`` of ``input_resolution``).

        Returns:
            Tensor of shape (B, N, C).
        """
        xB,xN,xC = x.shape

        # Use the configured grid so non-square resolutions fold correctly;
        # every other step of this module already depends on window_size.
        xH, xW = self.window_size
        if xH * xW != xN:
            # Legacy behaviour: assume a square grid when the token count
            # does not match the configured resolution.
            xH = xW = int(xN ** (1/2))
        qv = self.qv(x, (xH,xW))

        # qv splitting: (2, B, heads, N, C // (2 * heads))
        qv = qv.view(xB, xN, 2, self.num_heads, xC // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
        q, v = qv[0], qv[1]

        # spatial self-correlation (S-SC)
        x_spatial = self.spatial_self_correlation(q, v)

        # channel self-correlation (C-SC)
        x_channel = self.channel_self_correlation(q, v)

        # spatial-channel information fusion
        x = torch.cat([x_spatial, x_channel], -1)
        x = self.proj_drop(self.proj(x))

        return x.view(xB,-1,xC)

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'



class MLLABlock(nn.Module):
    r""" MLLA Block.

    Conditional position encoding, a gated linear-attention branch and an MLP,
    each wrapped in a residual connection.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution ``(H, W)``.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): Forwarded to ``LinearAttention``. Default: True
        drop (float, optional): Dropout rate used inside the MLP. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. The same value is
            also passed to ``LinearAttention`` as its ``value_drop``. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        base_win_size (tuple[int], optional): Base window of the attention;
            clamped so it never exceeds ``input_resolution``. Default: (2, 2)
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., drop_path=0.,
                 act_layer=nn.GELU, base_win_size=(2,2),  norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.window_size = input_resolution

        # base window size
        min_h = min(self.window_size[0], base_win_size[0])
        min_w = min(self.window_size[1], base_win_size[1])
        self.base_win_size = (min_h, min_w)

        # Resolution after forward() pads the map to a multiple of the base
        # window (same formula as check_image_size). The attention has to be
        # built for this grid, because that is what it actually receives.
        pad_h = (min_h - self.window_size[0] % min_h) % min_h
        pad_w = (min_w - self.window_size[1] % min_w) % min_w
        self.padded_resolution = (self.window_size[0] + pad_h, self.window_size[1] + pad_w)

        self.cpe1 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.in_proj = nn.Linear(dim, dim)
        self.act_proj = nn.Linear(dim, dim)
        self.dwc = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.act = nn.SiLU()
        self.attn = LinearAttention(dim=dim, input_resolution=self.padded_resolution, num_heads=num_heads, qkv_bias=qkv_bias, value_drop=drop_path, base_win_size=self.base_win_size)
        self.out_proj = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.cpe2 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def check_image_size(self, x, win_size):
        """Reflect-pad a (B, H, W, C) map on the bottom/right to multiples of ``win_size``."""
        x = x.permute(0,3,1,2).contiguous()
        _, _, h, w = x.size()
        mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
        mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]

        # Reflect padding cannot exceed the map size, so a large pad is done
        # in two steps. With base_win_size clamped to the resolution in
        # __init__, the pad is always smaller than the map and this branch is
        # not expected to trigger from forward().
        if mod_pad_h >= h or mod_pad_w >= w:
            pad_h, pad_w = h-1, w-1
            x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
        else:
            pad_h, pad_w = 0, 0

        mod_pad_h = mod_pad_h - pad_h
        mod_pad_w = mod_pad_w - pad_w

        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        x = x.permute(0,2,3,1).contiguous()
        return x

    def forward(self, x):
        """Apply the block to tokens ``x`` of shape (B, H*W, C); the shape is preserved."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        # conditional position encoding + residual
        x = x + self.cpe1(x.reshape(B, H, W, C).permute(0, 3, 1, 2)).flatten(2).permute(0, 2, 1)
        shortcut = x

        x = self.norm1(x)
        act_res = self.act(self.act_proj(x))  # gate applied to the attention output
        x = self.in_proj(x).view(B, H, W, C)
        x = self.act(self.dwc(x.permute(0, 3, 1, 2))).permute(0, 2, 3, 1)

        # pad to a multiple of the base window
        x = self.check_image_size(x, self.base_win_size)
        _, H_pad, W_pad, _ = x.shape # shape after padding

        # Linear Attention on the padded grid; fold back with the padded
        # shape so the crop below really removes the padded rows/columns.
        x = self.attn(x.view(B, H_pad * W_pad, C)).view(B, H_pad, W_pad, C)

        # unpad
        x = x[:, :H, :W, :].contiguous().view(B, L, C)

        x = self.out_proj(x * act_res)
        x = shortcut + self.drop_path(x)
        x = x + self.cpe2(x.reshape(B, H, W, C).permute(0, 3, 1, 2)).flatten(2).permute(0, 2, 1)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"mlp_ratio={self.mlp_ratio}"


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Halves the spatial resolution with a stride-2 depthwise convolution and
    doubles the channel count.

    Args:
        input_resolution (tuple[int]): Resolution ``(H, W)`` of the input feature.
        dim (int): Number of input channels; the output has ``2 * dim``.
        ratio (float): Expansion factor of the hidden width, which is
            ``int(2 * dim * ratio)``.
    """

    def __init__(self, input_resolution, dim, ratio=4.0):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim

        merged_channels = 2 * dim
        hidden_channels = int(merged_channels * ratio)
        expand = ConvLayer(dim, hidden_channels, kernel_size=1, norm=None)
        # depthwise (groups == channels) 3x3 conv with stride 2 does the downsampling
        reduce = ConvLayer(hidden_channels, hidden_channels, kernel_size=3, stride=2, padding=1,
                           groups=hidden_channels, norm=None)
        project = ConvLayer(hidden_channels, merged_channels, kernel_size=1, act_func=None)
        self.conv = nn.Sequential(expand, reduce, project)

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, H'*W', 2*C after the stride-2 convolution
        """
        height, width = self.input_resolution
        batch, length, channels = x.shape
        assert length == height * width, "input feature has wrong size"
        feature_map = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2)
        merged = self.conv(feature_map)
        return merged.flatten(2).permute(0, 2, 1)


class BasicLayer(nn.Module):
    """ A basic MLLA layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate.
            A list or tuple supplies one rate per block (indexed by block position);
            a single float is shared by all blocks. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, mlp_ratio=4., qkv_bias=True, drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Tuples are treated like lists: previously a tuple was handed to every
        # block unchanged and failed there when compared against a float.
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            MLLABlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads,
                      mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
                      drop_path=drop_path[i] if per_block_rates else drop_path, norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks in order, then the optional downsample layer."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # use_reentrant is passed explicitly, as the PyTorch docs
                # recommend; leaving it out triggers a warning on recent releases.
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


class Stem(nn.Module):
    r""" Stem

    Convolutional patch embedding. The two stride-2 convolutions always
    reduce the resolution by a factor of 4, whereas ``patches_resolution`` is
    computed from ``patch_size``; the two agree only for ``patch_size=4``
    (with image sides divisible by 4).

    Args:
        img_size (int): Image size.  Default: 128.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of output channels per token. Default: 96.
    """

    def __init__(self, img_size=128, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.patches_resolution = [rows, cols]
        self.num_patches = rows * cols

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        half_dim = embed_dim // 2
        wide_dim = embed_dim * 4
        # every stem convolution is 3x3 with padding 1 and no bias unless overridden
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = ConvLayer(in_chans, half_dim, stride=2, **conv3x3)
        self.conv2 = nn.Sequential(
            ConvLayer(half_dim, half_dim, stride=1, **conv3x3),
            ConvLayer(half_dim, half_dim, stride=1, act_func=None, **conv3x3),
        )
        self.conv3 = nn.Sequential(
            ConvLayer(half_dim, wide_dim, stride=2, **conv3x3),
            ConvLayer(wide_dim, embed_dim, kernel_size=1, bias=False, act_func=None),
        )

    def forward(self, x):
        """Embed images (B, C, H, W) into tokens (B, L, embed_dim)."""
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        features = self.conv1(x)
        features = self.conv2(features) + features  # residual around conv2
        features = self.conv3(features)
        return features.flatten(2).transpose(1, 2)


class CustomModel(nn.Module):
    """Hierarchical classifier: a conv stem, stacked MLLA stages and a linear head.

    Args:
        img_size (int | tuple(int)): Input image size. Default: 128
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for the classification head;
            a value <= 0 replaces the head with an identity. Default: 8
        embed_dim (int): Patch embedding dimension. Default: 64
        depths (tuple(int)): Depth of each MLLA layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=128, patch_size=4, in_chans=3, num_classes=8,
                 embed_dim=64, depths=[ 2, 4, 8, 4 ], num_heads=[ 2, 4, 8, 16 ],
                 mlp_ratio=4., qkv_bias=True, drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        # channels double at every stage after the first
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        self.patch_embed = Stem(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patches_resolution = self.patch_embed.patches_resolution
        grid_h, grid_w = self.patches_resolution[0], self.patches_resolution[1]

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one linearly increasing rate per block
        block_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        # build layers
        self.layers = nn.ModuleList()
        final_stage = self.num_layers - 1
        first_block = 0
        for stage, depth in enumerate(depths):
            factor = 2 ** stage
            self.layers.append(BasicLayer(
                dim=int(embed_dim * factor),
                input_resolution=(grid_h // factor, grid_w // factor),
                depth=depth,
                num_heads=num_heads[stage],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias, drop=drop_rate,
                drop_path=block_rates[first_block:first_block + depth],
                norm_layer=norm_layer,
                # every stage except the last ends with a downsampling step
                downsample=PatchMerging if stage < final_stage else None,
                use_checkpoint=use_checkpoint,
            ))
            first_block += depth

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        if num_classes > 0:
            self.head = nn.Linear(self.num_features, num_classes)
        else:
            self.head = nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm modules; other modules keep their defaults."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    def forward_features(self, x):
        """Return pooled features of shape (B, num_features) for images ``x``."""
        tokens = self.patch_embed(x)
        if self.ape:
            tokens = tokens + self.absolute_pos_embed
        tokens = self.pos_drop(tokens)

        for stage in self.layers:
            tokens = stage(tokens)

        tokens = self.norm(tokens)  # B L C
        pooled = self.avgpool(tokens.transpose(1, 2))  # B C 1
        return torch.flatten(pooled, 1)

    def forward(self, x):
        """Return class logits of shape (B, num_classes)."""
        return self.head(self.forward_features(x))

# Smoke test: build a 10-class model and time one forward pass on a random image.
model = CustomModel(num_classes=10)  # pass the number of classes to the model
# Inference settings: eval() makes BatchNorm use its running statistics instead
# of updating them from this throwaway input, and no_grad() keeps autograd
# bookkeeping out of the timing.
model.eval()
x = torch.rand(1,3,128,128)

start = time.perf_counter()  # monotonic clock intended for measuring intervals
with torch.no_grad():
    print(model(x).shape)
end = time.perf_counter()
print(end - start)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 10])
0.4086616039276123


In [3]:
# 1. Data preprocessing: resize to the model's 128x128 input, convert to a
# tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Dataset folders: one sub-directory per split under the dataset root.
dataset_root = '/kaggle/input/alphabet/Alphabet'
train_dir, val_dir, test_dir = (
    os.path.join(dataset_root, split_name)
    for split_name in ('train', 'val', 'Test_Alphabet')
)

train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(split_dir, transform=transform)
    for split_dir in (train_dir, val_dir, test_dir)
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=32, shuffle=False)
    for eval_dataset in (val_dataset, test_dataset)
)

# Class names come from the training folders.
classes = train_dataset.classes
num_classes = len(classes)

# Train and evaluate on the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Initialise the model with one output per class.
model = CustomModel(num_classes=num_classes).to(device)

# 3. Train the model with early stopping
def _run_epoch(model, loader, criterion, run_device, optimizer=None, desc=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.
    ``desc`` (if given) wraps the loader in a tqdm progress bar.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    loss_sum, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in batches:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # weight the batch-mean loss by the batch size to get a per-sample sum
            loss_sum += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("data loader produced no samples; cannot compute epoch metrics")
    return loss_sum / total, 100 * correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=25, early_stopping_patience=5):
    """Train ``model`` with early stopping on validation accuracy.

    After every epoch the model is validated. When validation accuracy
    improves, the weights are saved to ``best_custom_model.pth``; after
    ``early_stopping_patience`` consecutive epochs without improvement the
    loop stops. The per-epoch history is written to ``training_history.csv``
    and the loss/accuracy curves to ``train_val_metrics.png``.

    Args:
        model: Network to train; batches are moved to the device of its parameters.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs (int): Maximum number of epochs.
        early_stopping_patience (int): Epochs without improvement before stopping.

    Returns:
        pandas.DataFrame: One row per completed epoch with the columns
        ``epoch``, ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``.

    Raises:
        ValueError: If a loader yields no samples.
    """
    # Use the device the model already lives on instead of a notebook global.
    run_device = next(model.parameters()).device

    best_val_acc = 0.0
    patience_counter = 0
    history = []

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, run_device,
                                           optimizer=optimizer, desc=f'Train Epoch {epoch+1}')
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, run_device)

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })

        print(f"Epoch {epoch+1}: Train Acc = {train_acc:.2f}%, Val Acc = {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_custom_model.pth')
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"  No improvement. Patience: {patience_counter}/{early_stopping_patience}")
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered.")
                break

    history_df = pd.DataFrame(history)
    history_df.to_csv('training_history.csv', index=False)

    # Plot from the list of records so an empty history (num_epochs=0) still works.
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    loss_ax.plot([row['train_loss'] for row in history], label='Train Loss')
    loss_ax.plot([row['val_loss'] for row in history], label='Val Loss')
    loss_ax.legend()
    loss_ax.set_title('Loss')
    loss_ax.set_xlabel('Epoch')

    acc_ax.plot([row['train_acc'] for row in history], label='Train Acc')
    acc_ax.plot([row['val_acc'] for row in history], label='Val Acc')
    acc_ax.legend()
    acc_ax.set_title('Accuracy')
    acc_ax.set_xlabel('Epoch')

    fig.tight_layout()
    fig.savefig('train_val_metrics.png')
    plt.close(fig)

    return history_df

# 4. Evaluate the model
def evaluate_model(model, test_loader):
    """Evaluate the best saved checkpoint on ``test_loader``.

    Loads ``best_custom_model.pth`` into ``model``, prints the test loss,
    accuracy and a per-class classification report, and saves the confusion
    matrix to ``confusion_matrix.png``. Class names are read from the
    module-level ``classes`` list.

    Args:
        model: Network with the same architecture as the saved checkpoint.
        test_loader: Iterable of ``(inputs, labels)`` test batches.
    """
    # Evaluate on whichever device the model already lives on.
    eval_device = next(model.parameters()).device

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # weights_only restricts unpickling to tensors and plain containers, which
    # is all a state_dict holds.
    state_dict = torch.load('best_custom_model.pth', map_location=eval_device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    all_preds, all_labels = [], []
    test_loss, correct, total = 0.0, 0, 0
    loss_fn = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Testing'):
            inputs, labels = inputs.to(eval_device), labels.to(eval_device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    test_loss = test_loss / total
    test_acc = 100 * correct / total
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

    # Pin the label set to every known class index. Otherwise a class that
    # never appears in the test labels or predictions makes the number of
    # labels differ from len(classes), which breaks target_names and shifts
    # the heatmap tick labels.
    label_ids = list(range(len(classes)))

    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=classes))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(12,10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

# 5. Main
def main(num_epochs=100, learning_rate=0.0001, early_stopping_patience=10,
         input_res=(3, 128, 128)):
    """Print model statistics, then train and evaluate the model.

    Uses the module-level ``model``, ``train_loader``, ``val_loader`` and
    ``test_loader`` defined elsewhere in the notebook.

    Args:
        num_epochs: Upper bound on training epochs passed to ``train_model``.
        learning_rate: Learning rate for the Adam optimizer.
        early_stopping_patience: Patience value passed to ``train_model``.
        input_res: Input shape handed to ptflops for the complexity estimate,
            given without the batch dimension.
    """
    print("Custom Model Structure:")
    print(model)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}\n")

    # Compute model complexity.
    # NOTE(review): ptflops names this value "macs" while the label printed
    # below says FLOPs - confirm which unit is intended for reporting.
    from ptflops import get_model_complexity_info
    macs, params = get_model_complexity_info(model, input_res, as_strings=True,
                                             print_per_layer_stat=False, verbose=False)
    print(f"FLOPs: {macs}")
    print(f"Params (from ptflops): {params}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs=num_epochs,
                early_stopping_patience=early_stopping_patience)
    evaluate_model(model, test_loader)

# Entry point: run the full pipeline when this code executes as the main
# module (which includes running the cell in a notebook kernel), but not when
# it is imported from elsewhere.
if __name__ == '__main__':
    main()

Custom Model Structure:
CustomModel(
  (patch_embed): Stem(
    (conv1): ConvLayer(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (conv2): Sequential(
      (0): ConvLayer(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
      )
      (1): ConvLayer(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (conv3): Sequential(
      (0): ConvLayer(
        (conv): Conv2d(32, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True,

Train Epoch 1: 100%|██████████| 684/684 [09:54<00:00,  1.15it/s]


Epoch 1: Train Acc = 66.87%, Val Acc = 90.91%


Train Epoch 2: 100%|██████████| 684/684 [07:53<00:00,  1.45it/s]


Epoch 2: Train Acc = 94.59%, Val Acc = 95.10%


Train Epoch 3: 100%|██████████| 684/684 [07:52<00:00,  1.45it/s]


Epoch 3: Train Acc = 97.64%, Val Acc = 96.05%


Train Epoch 4: 100%|██████████| 684/684 [07:50<00:00,  1.45it/s]


Epoch 4: Train Acc = 98.79%, Val Acc = 96.79%


Train Epoch 5: 100%|██████████| 684/684 [07:52<00:00,  1.45it/s]


Epoch 5: Train Acc = 99.14%, Val Acc = 95.06%
  No improvement. Patience: 1/10


Train Epoch 6: 100%|██████████| 684/684 [07:52<00:00,  1.45it/s]


Epoch 6: Train Acc = 99.36%, Val Acc = 96.95%


Train Epoch 7: 100%|██████████| 684/684 [07:58<00:00,  1.43it/s]


Epoch 7: Train Acc = 99.28%, Val Acc = 97.86%


Train Epoch 8: 100%|██████████| 684/684 [08:04<00:00,  1.41it/s]


Epoch 8: Train Acc = 99.37%, Val Acc = 97.70%
  No improvement. Patience: 1/10


Train Epoch 9: 100%|██████████| 684/684 [08:08<00:00,  1.40it/s]


Epoch 9: Train Acc = 99.69%, Val Acc = 97.41%
  No improvement. Patience: 2/10


Train Epoch 10: 100%|██████████| 684/684 [08:17<00:00,  1.37it/s]


Epoch 10: Train Acc = 99.62%, Val Acc = 98.15%


Train Epoch 11: 100%|██████████| 684/684 [08:19<00:00,  1.37it/s]


Epoch 11: Train Acc = 99.45%, Val Acc = 98.27%


Train Epoch 12: 100%|██████████| 684/684 [08:17<00:00,  1.38it/s]


Epoch 12: Train Acc = 99.65%, Val Acc = 97.74%
  No improvement. Patience: 1/10


Train Epoch 13: 100%|██████████| 684/684 [08:14<00:00,  1.38it/s]


Epoch 13: Train Acc = 99.76%, Val Acc = 97.24%
  No improvement. Patience: 2/10


Train Epoch 14: 100%|██████████| 684/684 [08:06<00:00,  1.41it/s]


Epoch 14: Train Acc = 99.54%, Val Acc = 98.27%
  No improvement. Patience: 3/10


Train Epoch 15: 100%|██████████| 684/684 [08:10<00:00,  1.40it/s]


Epoch 15: Train Acc = 99.95%, Val Acc = 98.72%


Train Epoch 16: 100%|██████████| 684/684 [08:06<00:00,  1.41it/s]


Epoch 16: Train Acc = 99.52%, Val Acc = 98.64%
  No improvement. Patience: 1/10


Train Epoch 17: 100%|██████████| 684/684 [08:30<00:00,  1.34it/s]


Epoch 17: Train Acc = 99.68%, Val Acc = 98.64%
  No improvement. Patience: 2/10


Train Epoch 18: 100%|██████████| 684/684 [08:29<00:00,  1.34it/s]


Epoch 18: Train Acc = 99.91%, Val Acc = 98.11%
  No improvement. Patience: 3/10


Train Epoch 19: 100%|██████████| 684/684 [08:13<00:00,  1.39it/s]


Epoch 19: Train Acc = 99.61%, Val Acc = 98.15%
  No improvement. Patience: 4/10


Train Epoch 20: 100%|██████████| 684/684 [08:09<00:00,  1.40it/s]


Epoch 20: Train Acc = 99.87%, Val Acc = 97.45%
  No improvement. Patience: 5/10


Train Epoch 21: 100%|██████████| 684/684 [08:07<00:00,  1.40it/s]


Epoch 21: Train Acc = 99.84%, Val Acc = 98.93%


Train Epoch 22: 100%|██████████| 684/684 [08:03<00:00,  1.41it/s]


Epoch 22: Train Acc = 99.85%, Val Acc = 97.16%
  No improvement. Patience: 1/10


Train Epoch 23: 100%|██████████| 684/684 [08:07<00:00,  1.40it/s]


Epoch 23: Train Acc = 99.82%, Val Acc = 98.85%
  No improvement. Patience: 2/10


Train Epoch 24: 100%|██████████| 684/684 [08:08<00:00,  1.40it/s]


Epoch 24: Train Acc = 99.89%, Val Acc = 98.44%
  No improvement. Patience: 3/10


Train Epoch 25: 100%|██████████| 684/684 [08:07<00:00,  1.40it/s]


Epoch 25: Train Acc = 99.84%, Val Acc = 98.64%
  No improvement. Patience: 4/10


Train Epoch 26: 100%|██████████| 684/684 [08:08<00:00,  1.40it/s]


Epoch 26: Train Acc = 99.77%, Val Acc = 98.97%


Train Epoch 27: 100%|██████████| 684/684 [08:08<00:00,  1.40it/s]


Epoch 27: Train Acc = 99.82%, Val Acc = 98.97%
  No improvement. Patience: 1/10


Train Epoch 28: 100%|██████████| 684/684 [08:05<00:00,  1.41it/s]


Epoch 28: Train Acc = 99.83%, Val Acc = 98.97%
  No improvement. Patience: 2/10


Train Epoch 29: 100%|██████████| 684/684 [08:03<00:00,  1.41it/s]


Epoch 29: Train Acc = 99.98%, Val Acc = 98.89%
  No improvement. Patience: 3/10


Train Epoch 30: 100%|██████████| 684/684 [08:05<00:00,  1.41it/s]


Epoch 30: Train Acc = 99.78%, Val Acc = 98.68%
  No improvement. Patience: 4/10


Train Epoch 31: 100%|██████████| 684/684 [08:05<00:00,  1.41it/s]


Epoch 31: Train Acc = 99.97%, Val Acc = 99.14%


Train Epoch 32: 100%|██████████| 684/684 [08:08<00:00,  1.40it/s]


Epoch 32: Train Acc = 100.00%, Val Acc = 99.26%


Train Epoch 33: 100%|██████████| 684/684 [08:07<00:00,  1.40it/s]


Epoch 33: Train Acc = 99.98%, Val Acc = 95.43%
  No improvement. Patience: 1/10


Train Epoch 34: 100%|██████████| 684/684 [08:12<00:00,  1.39it/s]


Epoch 34: Train Acc = 99.66%, Val Acc = 98.60%
  No improvement. Patience: 2/10


Train Epoch 35: 100%|██████████| 684/684 [08:10<00:00,  1.40it/s]


Epoch 35: Train Acc = 99.79%, Val Acc = 98.72%
  No improvement. Patience: 3/10


Train Epoch 36: 100%|██████████| 684/684 [08:07<00:00,  1.40it/s]


Epoch 36: Train Acc = 99.95%, Val Acc = 99.05%
  No improvement. Patience: 4/10


Train Epoch 37: 100%|██████████| 684/684 [08:07<00:00,  1.40it/s]


Epoch 37: Train Acc = 99.93%, Val Acc = 98.77%
  No improvement. Patience: 5/10


Train Epoch 38: 100%|██████████| 684/684 [08:13<00:00,  1.39it/s]


Epoch 38: Train Acc = 99.81%, Val Acc = 98.89%
  No improvement. Patience: 6/10


Train Epoch 39: 100%|██████████| 684/684 [08:14<00:00,  1.38it/s]


Epoch 39: Train Acc = 99.97%, Val Acc = 99.09%
  No improvement. Patience: 7/10


Train Epoch 40: 100%|██████████| 684/684 [08:12<00:00,  1.39it/s]


Epoch 40: Train Acc = 99.93%, Val Acc = 98.15%
  No improvement. Patience: 8/10


Train Epoch 41: 100%|██████████| 684/684 [08:21<00:00,  1.36it/s]


Epoch 41: Train Acc = 99.85%, Val Acc = 99.34%


Train Epoch 42: 100%|██████████| 684/684 [08:14<00:00,  1.38it/s]


Epoch 42: Train Acc = 99.93%, Val Acc = 98.48%
  No improvement. Patience: 1/10


Train Epoch 43: 100%|██████████| 684/684 [08:10<00:00,  1.39it/s]


Epoch 43: Train Acc = 99.89%, Val Acc = 99.26%
  No improvement. Patience: 2/10


Train Epoch 44: 100%|██████████| 684/684 [08:06<00:00,  1.41it/s]


Epoch 44: Train Acc = 99.89%, Val Acc = 98.72%
  No improvement. Patience: 3/10


Train Epoch 45: 100%|██████████| 684/684 [08:09<00:00,  1.40it/s]


Epoch 45: Train Acc = 99.92%, Val Acc = 99.26%
  No improvement. Patience: 4/10


Train Epoch 46: 100%|██████████| 684/684 [08:05<00:00,  1.41it/s]


Epoch 46: Train Acc = 100.00%, Val Acc = 99.34%
  No improvement. Patience: 5/10


Train Epoch 47: 100%|██████████| 684/684 [08:06<00:00,  1.41it/s]


Epoch 47: Train Acc = 100.00%, Val Acc = 99.42%


Train Epoch 48: 100%|██████████| 684/684 [08:04<00:00,  1.41it/s]


Epoch 48: Train Acc = 100.00%, Val Acc = 99.42%
  No improvement. Patience: 1/10


Train Epoch 49: 100%|██████████| 684/684 [08:05<00:00,  1.41it/s]


Epoch 49: Train Acc = 100.00%, Val Acc = 99.38%
  No improvement. Patience: 2/10


Train Epoch 50: 100%|██████████| 684/684 [08:02<00:00,  1.42it/s]


Epoch 50: Train Acc = 99.72%, Val Acc = 98.77%
  No improvement. Patience: 3/10


Train Epoch 51: 100%|██████████| 684/684 [08:03<00:00,  1.42it/s]


Epoch 51: Train Acc = 99.93%, Val Acc = 98.89%
  No improvement. Patience: 4/10


Train Epoch 52: 100%|██████████| 684/684 [08:03<00:00,  1.41it/s]


Epoch 52: Train Acc = 99.99%, Val Acc = 99.14%
  No improvement. Patience: 5/10


Train Epoch 53: 100%|██████████| 684/684 [08:09<00:00,  1.40it/s]


Epoch 53: Train Acc = 100.00%, Val Acc = 99.26%
  No improvement. Patience: 6/10


Train Epoch 54: 100%|██████████| 684/684 [08:11<00:00,  1.39it/s]


Epoch 54: Train Acc = 100.00%, Val Acc = 99.30%
  No improvement. Patience: 7/10


Train Epoch 55: 100%|██████████| 684/684 [08:13<00:00,  1.39it/s]


Epoch 55: Train Acc = 100.00%, Val Acc = 99.30%
  No improvement. Patience: 8/10


Train Epoch 56: 100%|██████████| 684/684 [08:15<00:00,  1.38it/s]


Epoch 56: Train Acc = 99.68%, Val Acc = 98.89%
  No improvement. Patience: 9/10


Train Epoch 57: 100%|██████████| 684/684 [08:15<00:00,  1.38it/s]


Epoch 57: Train Acc = 99.89%, Val Acc = 99.09%
  No improvement. Patience: 10/10
Early stopping triggered.


  model.load_state_dict(torch.load('best_custom_model.pth'))
Testing: 100%|██████████| 85/85 [00:58<00:00,  1.46it/s]


Test Loss: 0.0145, Test Accuracy: 99.67%

Classification Report:
              precision    recall  f1-score   support

           A       1.00      1.00      1.00       100
           B       1.00      1.00      1.00       100
       Blank       0.99      0.99      0.99       100
           C       1.00      1.00      1.00       100
           D       1.00      1.00      1.00       100
           E       1.00      1.00      1.00       100
           F       0.99      1.00      1.00       100
           G       1.00      1.00      1.00       100
           H       1.00      1.00      1.00       100
           I       1.00      1.00      1.00       100
           J       1.00      0.98      0.99       100
           K       1.00      1.00      1.00       100
           L       1.00      1.00      1.00       100
           M       1.00      1.00      1.00       100
           N       0.99      1.00      1.00       100
           O       1.00      1.00      1.00       100
           P    