In [1]:
!pip3 install triton torchinfo utils
!pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

Collecting git+https://github.com/sustcsonglin/flash-linear-attention
  Cloning https://github.com/sustcsonglin/flash-linear-attention to /tmp/pip-req-build-8hy0ez1k
  Running command git clone --filter=blob:none --quiet https://github.com/sustcsonglin/flash-linear-attention /tmp/pip-req-build-8hy0ez1k
  Resolved https://github.com/sustcsonglin/flash-linear-attention to commit 311d037b814facf5c67934a5d98d8cdb26ecad75
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
#TODO: 11/11 merge GLA into LeViT_impl and then test model

import torch
from einops import rearrange
import triton
import triton.language as tl

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import itertools
import utils
import timm

from fla.ops.gla import fused_chunk_gla, chunk_gla, fused_recurrent_gla

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset

from tqdm import tqdm

from torchinfo import summary

#import os
#os.environ['TRITON_DISABLE_BF16'] = '1'

In [3]:
class ConvNorm(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super(ConvNorm, self).__init__()
        self.linear = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return x

In [4]:
class Stem16(nn.Module):
    def __init__(self):
        super(Stem16, self).__init__()
        self.conv1 = ConvNorm(3, 32)
        self.act1 = nn.Hardswish()
        self.conv2 = ConvNorm(32, 64)
        self.act2 = nn.Hardswish()
        self.conv3 = ConvNorm(64, 128)
        self.act3 = nn.Hardswish()
        self.conv4 = ConvNorm(128, 256)

    def forward(self, x):
        x = self.act1(self.conv1(x))
        x = self.act2(self.conv2(x))
        x = self.act3(self.conv3(x))
        x = self.conv4(x)
        return x

In [5]:
class LinearNorm(nn.Module):
    def __init__(self, in_features, out_features):
        super(LinearNorm, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x):

        if x.dim() == 3:
            B, N, C = x.shape
            x = x.reshape(B * N, C)
            x = self.bn(self.linear(x))
            x = x.reshape(B, N, -1)
        else:
            x = self.bn(self.linear(x))
        return x

In [6]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads, attn_ratio=2):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        inner_dim = head_dim * num_heads * 3
        self.qkv = LinearNorm(dim, inner_dim)

        self.proj = nn.Sequential(
            nn.Hardswish(),
            LinearNorm(dim, dim)
        )

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

## GLA (Gated Linear Attention) Module

In [7]:
class GatedLinearAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.dim
        self.num_heads = config.num_heads

        self.gate_fn = nn.functional.silu
        assert config.use_gk and not config.use_gv, "Only use_gk is supported for simplicity."

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim//2, bias=False)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim//2, bias=False)
        self.k_gate =  nn.Sequential(nn.Linear(self.embed_dim, 16, bias=False), nn.Linear(16, self.embed_dim // 2))

        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.g_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

        self.head_dim = self.embed_dim // self.num_heads
        self.key_dim = self.embed_dim // self.num_heads
        self.scaling = self.key_dim ** -0.5
        self.group_norm = nn.LayerNorm(self.head_dim, eps=1e-5, elementwise_affine=False)

        self.post_init()

    def post_init(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
        if isinstance(self.k_gate, nn.Sequential):
            nn.init.xavier_uniform_(self.k_gate[0].weight, gain=2 ** -2.5)
            nn.init.xavier_uniform_(self.k_gate[1].weight, gain=2 ** -2.5)
        else:
            nn.init.xavier_uniform_(self.k_gate.weight, gain=2 ** -2.5)

    def forward(self, x, hidden_states=None):
        q = self.q_proj(x)
        k = self.k_proj(x) * self.scaling
        k_gate = self.k_gate(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        output, new_hidden_states = self.gated_linear_attention(q, k, v, k_gate, hidden_states=hidden_states)
        output = self.gate_fn(g) * output
        output = self.out_proj(output)
        #print(f"Inside GatedLinearAttention forward: new_hidden_states shape = {new_hidden_states.shape if new_hidden_states is not None else 'None'}")
        return output, new_hidden_states # this needs to be tensor, not tuple


    def gated_linear_attention(self, q, k, v, gk, normalizer=16, hidden_states=None):
        q = rearrange(q, 'b l (h d) -> b h l d', h = self.num_heads).contiguous()
        k = rearrange(k, 'b l (h d) -> b h l d', h = self.num_heads).contiguous()
        v = rearrange(v, 'b l (h d) -> b h l d', h = self.num_heads).contiguous()
        gk = rearrange(gk, 'b l (h d) -> b h l d', h = self.num_heads).contiguous()
        gk = F.logsigmoid(gk) / normalizer

        # for storing original dtype
        original_dtype = q.dtype

        if self.training:
            # cast inputs to float32 if needed
            if q.dtype == torch.bfloat16:
                q, k, v, gk = q.float(), k.float(), v.float(), gk.float()
            o, new_hidden_states = fused_chunk_gla(q, k, v, gk, initial_state=hidden_states, output_final_state=True)
            # cast back to origianl dtype if needed
            if o.dtype != original_dtype:
              o = o.type(original_dtype)

        else:
            o = fused_recurrent_gla(q, k, v, gk)

            new_hidden_states = None

        #print(f"Inside gated_linear_attention: new_hidden_states shape = {new_hidden_states.shape if new_hidden_states is not None else 'None'}")

        if isinstance(o, tuple):
          o = o[0]

        o = self.group_norm(o)
        o = rearrange(o, 'b h l d -> b l (h d)')
        return o, new_hidden_states

In [8]:
# define seperate config object for GLA input
class Config:
    def __init__(self, dim, num_heads, use_gk=True, use_gv=False):
        self.dim = dim
        self.num_heads = num_heads
        self.use_gk = use_gk
        self.use_gv = use_gv

In [9]:
class LevitMlp(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(LevitMlp, self).__init__()
        self.ln1 = LinearNorm(in_features, hidden_features)
        self.act = nn.Hardswish()
        self.drop = nn.Dropout(p=0.0, inplace=False)
        self.ln2 = LinearNorm(hidden_features, out_features)

    def forward(self, x):
        x = self.ln1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.ln2(x)
        return x

In [10]:
class LevitBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=2): # hidden state from GLA
        super(LevitBlock, self).__init__()
        #self.attn = Attention(dim, num_heads) # -> GLA, hidden state updates
        self.attn = GatedLinearAttention(Config(dim, num_heads))
        self.drop_path1 = nn.Identity()
        self.mlp = LevitMlp(dim, dim * mlp_ratio, dim)
        self.drop_path2 = nn.Identity()

    def forward(self, x):
        attn_output, hidden_state = self.attn(x) # GLA output is tuple (o, hidden_state)
        x = x + self.drop_path1(attn_output) # assign tuple
        x = x + self.drop_path2(self.mlp(x))
        return x, hidden_state

In [11]:
class AttentionDownsample(nn.Module):
    def __init__(self, dim, out_dim, num_heads, attn_ratio=2):
        super(AttentionDownsample, self).__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        inner_dim = dim * attn_ratio * num_heads
        self.kv = LinearNorm(dim, inner_dim)

        self.q = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=2, stride=2),
            nn.Flatten(start_dim=1)
        )

        self.proj = nn.Sequential(
            nn.Hardswish(),
            LinearNorm(dim, out_dim)
        )

    def forward(self, x):

        if isinstance(x, tuple):
          x, hidden_state = x
        else:
          hidden_state = None

        B, N, C = x.shape
        H = W = int(N ** 0.5)
        x = x.reshape(B, C, H, W)

        kv = self.kv(x.flatten(2).transpose(1, 2))
        q = self.q(x)

        q = q.reshape(B, -1, C)
        x = self.proj(q)
        return x

In [12]:
class LevitDownsample(nn.Module):
    def __init__(self, dim, out_dim, num_heads, attn_ratio=2):
        super(LevitDownsample, self).__init__()
        self.attn_downsample = AttentionDownsample(dim, out_dim, num_heads, attn_ratio)
        self.mlp = LevitMlp(out_dim, out_dim * attn_ratio, out_dim)
        self.drop_path = nn.Identity()

    def forward(self, x):

        if isinstance(x, tuple):
          x, hidden_state = x
        else:
          hidden_state = None

        x = self.attn_downsample(x)
        x = self.drop_path(self.mlp(x))

        if hidden_state is None:
          return x
        else:
          return x, hidden_state

In [13]:
class NormLinear(nn.Module):
    def __init__(self, in_features, out_features, dropout_prob=0.0):
        super(NormLinear, self).__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.drop = nn.Dropout(p=dropout_prob, inplace=False)
        self.linear = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x):
        x = self.bn(x)
        x = self.drop(x)
        x = self.linear(x)
        return x


In [14]:
class LevitStage(nn.Module):
    def __init__(self, dim, out_dim, num_heads, num_blocks, downsample=True):
        super(LevitStage, self).__init__()
        self.downsample = LevitDownsample(dim, out_dim, num_heads) if downsample else nn.Identity()
        self.blocks = nn.Sequential(*[LevitBlock(out_dim, num_heads) for _ in range(num_blocks)])

    def forward(self, x):
        x = self.downsample(x)
        hidden_states = []
        for block in self.blocks:
            if isinstance(x, tuple):
                x, hidden_state = x
            else:
                hidden_state = None
            x, hidden_state_block = block(x) # use only x as input
            hidden_states.append(hidden_state_block)
            #print(f"Inside LevitStage: Block hidden_state shape = {hidden_state_block.shape if hidden_state_block is not None else 'None'}")  # 추가된 디버그 출력
        return x, hidden_states

In [15]:
class GLALeViT(nn.Module):
    def __init__(self, num_classes=37):
        super(GLALeViT, self).__init__()

        self.stem = Stem16()

        self.stages = nn.Sequential(
            LevitStage(dim=256, out_dim=256, num_heads=4, num_blocks=3, downsample=False),
            LevitStage(dim=256, out_dim=384, num_heads=6, num_blocks=3, downsample=True),
            LevitStage(dim=384, out_dim=512, num_heads=8, num_blocks=2, downsample=True)
        )

        self.head = NormLinear(in_features=512, out_features=num_classes, dropout_prob=0.0)
        self.head_dist = NormLinear(in_features=512, out_features=num_classes, dropout_prob=0.0)

    def forward(self, x):
        x = self.stem(x)
        B, C, H, W = x.shape
        x = x.view(B, C, -1).transpose(1, 2)

        x, all_hidden_states = self.stages(x)
        all_hidden_states = [state for state in all_hidden_states if state is not None]
        out = self.head(x.mean(dim=1))
        out_dist = self.head_dist(x.mean(dim=1))
        return out, all_hidden_states

        #x = self.stages(x)
        #out = self.head(x.mean(dim=1))
        #out_dist = self.head_dist(x.mean(dim=1))
        #return out


In [16]:
model = GLALeViT()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

batch_size = 32
learning_rate = 0.001
num_epochs = 50

In [17]:
model.to(device)

GLALeViT(
  (stem): Stem16(
    (conv1): ConvNorm(
      (linear): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act1): Hardswish()
    (conv2): ConvNorm(
      (linear): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act2): Hardswish()
    (conv3): ConvNorm(
      (linear): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act3): Hardswish()
    (conv4): ConvNorm(
      (linear): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stages): Sequential(
    (0):

In [18]:
from torchinfo import summary
print(summary(model, input_size=(32, 3, 224, 224)))

Layer (type:depth-idx)                                  Output Shape              Param #
GLALeViT                                                [32, 37]                  --
├─Stem16: 1-1                                           [32, 256, 14, 14]         --
│    └─ConvNorm: 2-1                                    [32, 32, 112, 112]        --
│    │    └─Conv2d: 3-1                                 [32, 32, 112, 112]        864
│    │    └─BatchNorm2d: 3-2                            [32, 32, 112, 112]        64
│    └─Hardswish: 2-2                                   [32, 32, 112, 112]        --
│    └─ConvNorm: 2-3                                    [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-3                                 [32, 64, 56, 56]          18,432
│    │    └─BatchNorm2d: 3-4                            [32, 64, 56, 56]          128
│    └─Hardswish: 2-4                                   [32, 64, 56, 56]          --
│    └─ConvNorm: 2-5                                  

In [19]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [20]:
trainval_data = datasets.OxfordIIITPet(root="data", split="trainval", target_types="category", download=True, transform=transform)
test_data = datasets.OxfordIIITPet(root="data", split="test", target_types="category", download=True, transform=transform)
combined_data = ConcatDataset([trainval_data, test_data])

train_size = int(0.7 * len(combined_data))
val_size = int(0.15 * len(combined_data))
test_size = len(combined_data) - train_size - val_size
train_data, val_data, test_data = random_split(combined_data, [train_size, val_size, test_size])

In [21]:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print(f"Train set size: {len(train_data)}")
print(f"Validation set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")

Train set size: 5144
Validation set size: 1102
Test set size: 1103


In [22]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [35]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    hidden_states = None

    num_batches = len(train_loader)
    print_interval = num_batches // 10

    for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training")):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Check if 'outputs' is a tuple and extract the logits if necessary
        if isinstance(outputs, tuple):
            hidden_states = outputs[1]
            outputs = outputs[0]  # Assuming the first element of the tuple contains the logits

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (batch_idx + 1) % print_interval == 0 and hidden_states is not None:
            if isinstance(hidden_states, list):
                print(f"{(batch_idx + 1) / num_batches * 100:.0f}% 진행 - hidden states element shapes: {[h.shape for h in hidden_states]}")

            else:
                print(f"{(batch_idx + 1) / num_batches * 100:.0f}% 진행 - hidden states shape: {hidden_states.shape}")
    epoch_loss = running_loss / len(train_loader)
    accuracy = 100 * correct / total
    if torch.isnan(torch.tensor(epoch_loss)):
      print(f"Epoch {epoch+1}: NaN epoch loss detected!")
      print(f"running loss: {running_loss}, train_loader length: {len(train_loader)}")
    print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
    if hidden_states is not None:
        if isinstance(hidden_states, list):

            print(f"Epoch 종료 시 hidden states length: {len(hidden_states)}")

            hidden_states = torch.stack(hidden_states)
            print(f"Epoch 종료 시 hidden states shape: {hidden_states.shape}")

        else:
            print(f"Epoch 종료 시 hidden states shape: {hidden_states.shape}")


In [24]:
def evaluate(model, data_loader, criterion, device, phase="Validation"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # Check if 'outputs' is a tuple and extract the logits if necessary
            if isinstance(outputs, tuple):
                outputs = outputs[0]  # Assuming the first element of the tuple contains the logits

            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(data_loader)
    accuracy = 100 * correct / total
    print(f"{phase} Loss: {epoch_loss:.4f}, {phase} Accuracy: {accuracy:.2f}%")

In [25]:
def measure_inference_time(model, data_loader, device):
    model.eval()
    times = []

    with torch.no_grad():
        for inputs, _ in data_loader:
            inputs = inputs.to(device)
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)

            start_time.record()
            _ = model(inputs)  # inference 수행
            end_time.record()

            # 시간 측정
            torch.cuda.synchronize()  # CUDA에서 모든 커널이 완료될 때까지 대기
            elapsed_time = start_time.elapsed_time(end_time)  # 밀리초 단위로 반환
            times.append(elapsed_time)

    # 통계량 계산
    times_np = np.array(times)
    total_inferences = len(times_np)
    avg_time = np.mean(times_np)
    std_dev = np.std(times_np)
    max_time = np.max(times_np)
    min_time = np.min(times_np)

    # 결과 출력
    print(f"Inference Time Measurement Results:")
    print(f"Total Inferences: {total_inferences}")
    print(f"Average Time: {avg_time:.2f} ms")
    print(f"Standard Deviation: {std_dev:.2f} ms")
    print(f"Maximum Time: {max_time:.2f} ms")
    print(f"Minimum Time: {min_time:.2f} ms")

    return times

In [36]:
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    train(model, train_loader, criterion, optimizer, device)
    evaluate(model, val_loader, criterion, device, phase="Validation")


Epoch 1/50


Training:  11%|█         | 17/161 [00:03<00:28,  5.00it/s]

10% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  20%|█▉        | 32/161 [00:06<00:26,  4.89it/s]

20% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  30%|██▉       | 48/161 [00:10<00:23,  4.87it/s]

30% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  40%|████      | 65/161 [00:13<00:19,  4.91it/s]

40% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  50%|████▉     | 80/161 [00:16<00:16,  4.91it/s]

50% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  60%|█████▉    | 96/161 [00:20<00:14,  4.47it/s]

60% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  70%|██████▉   | 112/161 [00:23<00:10,  4.84it/s]

70% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  80%|███████▉  | 128/161 [00:26<00:06,  4.96it/s]

80% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  90%|█████████ | 145/161 [00:30<00:03,  4.77it/s]

89% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training: 100%|██████████| 161/161 [00:33<00:00,  4.79it/s]


99% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]
Epoch 1 Loss: 3.5110
Epoch 종료 시 hidden states length: 2
Epoch 종료 시 hidden states shape: torch.Size([2, 24, 8, 32, 64])


Validation: 100%|██████████| 35/35 [00:05<00:00,  6.09it/s]


Validation Loss: 3.4959, Validation Accuracy: 4.54%

Epoch 2/50


Training:  11%|█         | 17/161 [00:03<00:30,  4.77it/s]

10% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  20%|█▉        | 32/161 [00:06<00:25,  5.00it/s]

20% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  30%|██▉       | 48/161 [00:09<00:23,  4.84it/s]

30% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  40%|███▉      | 64/161 [00:13<00:18,  5.16it/s]

40% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  50%|████▉     | 80/161 [00:16<00:16,  5.00it/s]

50% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  60%|██████    | 97/161 [00:19<00:13,  4.90it/s]

60% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  70%|██████▉   | 112/161 [00:22<00:09,  4.99it/s]

70% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  80%|███████▉  | 128/161 [00:26<00:06,  4.73it/s]

80% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  89%|████████▉ | 144/161 [00:29<00:03,  4.88it/s]

89% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training: 100%|██████████| 161/161 [00:32<00:00,  4.89it/s]


99% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]
Epoch 2 Loss: 3.5054
Epoch 종료 시 hidden states length: 2
Epoch 종료 시 hidden states shape: torch.Size([2, 24, 8, 32, 64])


Validation: 100%|██████████| 35/35 [00:05<00:00,  5.98it/s]


Validation Loss: 3.6071, Validation Accuracy: 4.26%

Epoch 3/50


Training:  10%|▉         | 16/161 [00:03<00:29,  4.87it/s]

10% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  20%|█▉        | 32/161 [00:06<00:25,  5.10it/s]

20% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  30%|███       | 49/161 [00:09<00:21,  5.12it/s]

30% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  40%|████      | 65/161 [00:13<00:18,  5.06it/s]

40% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  50%|████▉     | 80/161 [00:16<00:15,  5.15it/s]

50% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  60%|█████▉    | 96/161 [00:19<00:12,  5.03it/s]

60% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  70%|███████   | 113/161 [00:22<00:10,  4.64it/s]

70% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  80%|███████▉  | 128/161 [00:25<00:06,  5.07it/s]

80% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  90%|█████████ | 145/161 [00:29<00:03,  4.89it/s]

89% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training: 100%|██████████| 161/161 [00:32<00:00,  4.94it/s]


99% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]
Epoch 3 Loss: 3.4997
Epoch 종료 시 hidden states length: 2
Epoch 종료 시 hidden states shape: torch.Size([2, 24, 8, 32, 64])


Validation: 100%|██████████| 35/35 [00:05<00:00,  5.91it/s]


Validation Loss: 3.5063, Validation Accuracy: 4.63%

Epoch 4/50


Training:  10%|▉         | 16/161 [00:03<00:28,  5.07it/s]

10% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  20%|█▉        | 32/161 [00:06<00:26,  4.87it/s]

20% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  30%|██▉       | 48/161 [00:09<00:22,  4.94it/s]

30% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  40%|███▉      | 64/161 [00:13<00:20,  4.80it/s]

40% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  50%|████▉     | 80/161 [00:16<00:16,  4.99it/s]

50% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  60%|█████▉    | 96/161 [00:19<00:13,  4.74it/s]

60% 진행 - hidden states element shapes: [torch.Size([32, 8, 32, 64]), torch.Size([32, 8, 32, 64])]


Training:  60%|██████    | 97/161 [00:20<00:13,  4.82it/s]


KeyboardInterrupt: 

In [37]:
print("\nFinal Test Evaluation")
evaluate(model, test_loader, criterion, device, phase="Test")


Final Test Evaluation


Test: 100%|██████████| 35/35 [00:06<00:00,  5.26it/s]

Test Loss: 3.5360, Test Accuracy: 5.08%





In [38]:
times = measure_inference_time(model, test_loader, device)

Inference Time Measurement Results:
Total Inferences: 35
Average Time: 14.30 ms
Standard Deviation: 0.57 ms
Maximum Time: 16.06 ms
Minimum Time: 13.80 ms


In [39]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Assuming 'device' is defined and set to your CUDA device (e.g., 'cuda:0')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sample input for testing, moved to the device
sample_input = torch.rand(1, 3, 224, 224).to(device)  # Move input to device

# Initialize the model and move it to the device
model = GLALeViT(num_classes=37).to(device)  # Move model to device

# Set the model to evaluation mode
model.eval() # This line is added to set the model to evaluation mode

# Forward pass to get the output and hidden states
for batch_idx, (inputs, labels) in enumerate(train_loader):
    inputs = inputs.to(device)  # 데이터를 장치로 이동 (GPU 또는 CPU)

    # 모델의 forward pass를 수행하여 output과 hidden states를 얻음
    output, hidden_states = model(inputs)

    # 각 배치에 대해 hidden state를 출력하고 시각화합니다.
    print(f"\nBatch {batch_idx+1} - Hidden States:")
    for i, hidden_state in enumerate(hidden_states):
        if hidden_state is not None:
            print(f"Hidden state {i+1} shape: {hidden_state.shape}")

            # 첫 번째 이미지를 시각화 (배치에서 첫 번째 예제)
            hidden_state_np = hidden_state[0].detach().cpu().numpy()

            if hidden_state_np.ndim > 1:
                plt.imshow(hidden_state_np, cmap='gray')
                plt.title(f'Hidden State {i+1} - Batch {batch_idx+1}')
                plt.colorbar()
                plt.show()
        else:
            print(f"Hidden state {i+1} is None.")


Batch 1 - Hidden States:

Batch 2 - Hidden States:

Batch 3 - Hidden States:

Batch 4 - Hidden States:

Batch 5 - Hidden States:

Batch 6 - Hidden States:

Batch 7 - Hidden States:

Batch 8 - Hidden States:

Batch 9 - Hidden States:

Batch 10 - Hidden States:

Batch 11 - Hidden States:

Batch 12 - Hidden States:

Batch 13 - Hidden States:

Batch 14 - Hidden States:

Batch 15 - Hidden States:

Batch 16 - Hidden States:

Batch 17 - Hidden States:

Batch 18 - Hidden States:

Batch 19 - Hidden States:

Batch 20 - Hidden States:

Batch 21 - Hidden States:

Batch 22 - Hidden States:

Batch 23 - Hidden States:

Batch 24 - Hidden States:

Batch 25 - Hidden States:

Batch 26 - Hidden States:

Batch 27 - Hidden States:

Batch 28 - Hidden States:

Batch 29 - Hidden States:

Batch 30 - Hidden States:

Batch 31 - Hidden States:

Batch 32 - Hidden States:

Batch 33 - Hidden States:

Batch 34 - Hidden States:

Batch 35 - Hidden States:

Batch 36 - Hidden States:

Batch 37 - Hidden States:

Batch 38 

KeyboardInterrupt: 