In [55]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
import sklearn
from torch.optim.lr_scheduler import OneCycleLR, MultiStepLR
import os
import time
import pickle
import math
import random
import sys
import cv2
import gc
import glob
import datetime
import json
import pyarrow.parquet as pq
from sklearn.model_selection import KFold
from torch.utils.data import Subset
from torch.optim import Adam, AdamW, SGD, RMSprop, Adamax, Adadelta, Adagrad
from torch.cuda import amp
from torchvision import transforms
from sklearn.model_selection import train_test_split

In [56]:
def get_device_strategy(device='GPU'):
    IS_TPU = False

    if device == 'TPU':
        # Note: TPU support in PyTorch requires torch_xla library, typically used on Google Cloud Platform.
        try:
            import torch_xla.core.xla_model as xm
            device = xm.xla_device()
            IS_TPU = True
            print("Using TPU")
        except ImportError:
            raise ImportError("TPU support requires the torch_xla library.")
    
    elif device == 'GPU' or device == 'CPU' or device == 'MPS':
        if torch.cuda.is_available() and device == 'GPU':
            ngpu = torch.cuda.device_count()
            if ngpu > 1:
                print("Using multi GPU")
                device = torch.device('cuda')
            elif ngpu == 1:
                print("Using single GPU")
                device = torch.device('cuda')
            else:
                print("Using CPU")
                device = torch.device('cpu')
        elif device == 'MPS' and torch.backends.mps.is_available():
            print("Using MPS")
            device = torch.device('mps')
        else:
            print("Using CPU")
            device = torch.device('cpu')

    if device.type == 'cuda':
        ngpu = torch.cuda.device_count()
        print("Num GPUs Available: ", ngpu)
    elif device.type == 'mps':
        ngpu = 1
        print("Num MPS Devices Available: 1")
    else:
        ngpu = 0

    REPLICAS = ngpu if ngpu > 0 else 1
    print(f'REPLICAS: {REPLICAS}')

    return device, REPLICAS, IS_TPU

device, N_REPLICAS, IS_TPU = get_device_strategy(device='GPU')

# Output the device details
print(f'Using device: {device}')
print(f'Number of replicas: {N_REPLICAS}')
print(f'Is TPU: {IS_TPU}')

Using CPU
REPLICAS: 1
Using device: cpu
Number of replicas: 1
Is TPU: False


In [57]:
base_dir = os.getcwd()

In [58]:
train_df = pd.read_csv(os.path.join(base_dir, 'asl-signs', 'train.csv'))
paths = train_df['path'].values
paths = [os.path.join(base_dir, 'asl-signs', path) for path in paths]
signs = train_df['sign'].values
with open(os.path.join(base_dir, 'asl-signs', 'sign_to_prediction_index_map.json'), 'r') as f:
    sign_to_prediction_index_map = json.load(f)

def labels_to_ids(labels, mapping):
    return [mapping[label] for label in labels]

ids = labels_to_ids(signs, sign_to_prediction_index_map)

In [59]:
# Constants
ROWS_PER_FRAME = 543
MAX_LEN = 384
NUM_CLASSES = 250
PAD = -100.0
NOSE = [1, 2, 98, 327]
LNOSE = [98]
RNOSE = [327]
LIP = [
    0, 61, 185, 40, 39, 37, 267, 269, 270, 409,
    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
    95, 88, 178, 87, 14, 317, 402, 318, 324, 308,
]
LLIP = [84, 181, 91, 146, 61, 185, 40, 39, 37, 87, 178, 88, 95, 78, 191, 80, 81, 82]
RLIP = [314, 405, 321, 375, 291, 409, 270, 269, 267, 317, 402, 318, 324, 308, 415, 310, 311, 312]
POSE = [500, 502, 504, 501, 503, 505, 512, 513]
LPOSE = [513, 505, 503, 501]
RPOSE = [512, 504, 502, 500]
REYE = [
    33, 7, 163, 144, 145, 153, 154, 155, 133,
    246, 161, 160, 159, 158, 157, 173,
]
LEYE = [
    263, 249, 390, 373, 374, 380, 381, 382, 362,
    466, 388, 387, 386, 385, 384, 398,
]
LHAND = np.arange(468, 489).tolist()
RHAND = np.arange(522, 543).tolist()
POINT_LANDMARKS = LIP + LHAND + RHAND + NOSE + REYE + LEYE
NUM_NODES = len(POINT_LANDMARKS)

In [60]:
def flip_lr(x):
    # Assuming x coordinates are normalized [0,1] and x has shape [num_frames, num_landmarks, 3]
    x[..., 0] = 1 - x[..., 0]  # Flip the x-coordinate

    indices_to_swap = {
        tuple(LHAND): RHAND,
        tuple(RHAND): LHAND,
        tuple(LLIP): RLIP,
        tuple(RLIP): LLIP,
        tuple(LPOSE): RPOSE,
        tuple(RPOSE): LPOSE,
        tuple(LEYE): REYE,
        tuple(REYE): LEYE,
        tuple(LNOSE): RNOSE,
        tuple(RNOSE): LNOSE,
    }
    
    num_landmarks = x.shape[1]
    
    for k, v in indices_to_swap.items():
        # Ensure all indices are within bounds before swapping
        if max(max(k), max(v)) < num_landmarks:
            temp = x[:, k].clone()
            x[:, k] = x[:, v]
            x[:, v] = temp
    
    return x

def resample(x, rate_range=(0.8, 1.2)):
    original_length = x.size(0)
    new_length = int(original_length * torch.empty(1).uniform_(*rate_range).item())
    indices = torch.linspace(0, original_length - 1, new_length).long()
    return x[indices]

import torchvision.transforms.functional as F

def spatial_random_affine(x, scale=(0.8, 1.2), shear=(-15, 15), shift=(-0.1, 0.1), degree=(-30, 30)):
    # Apply random scaling
    scale_factor = torch.empty(1).uniform_(*scale).item()
    x[:, :, :2] *= scale_factor  # Scale x, y coordinates

    # Apply random rotation
    theta = torch.empty(1).uniform_(*degree) * (math.pi / 180)  # Convert degrees to radians
    rotation_matrix = torch.tensor([
        [torch.cos(theta), -torch.sin(theta)],
        [torch.sin(theta), torch.cos(theta)]
    ])
    x[:, :, :2] = torch.matmul(x[:, :, :2], rotation_matrix)  # Apply rotation

    # Apply random translation
    translation = torch.empty(2).uniform_(*shift) * x.shape[1]  # Adjust translation range based on shape
    x[:, :, :2] += translation

    return x

def temporal_crop(x, length=MAX_LEN):
    l = x.size(0)
    if l < length:
        # If the sequence is shorter than the desired length, return as is
        return x
    # Calculate a valid offset to avoid out-of-bounds indexing
    offset = torch.randint(0, l - length + 1, (1,)).item()
    return x[offset:offset+length]

def temporal_mask(x, size_range=(0.2, 0.4), mask_value=float('nan')):
    length = x.size(0)
    mask_size = int(torch.empty(1).uniform_(*size_range).item() * length)
    start = torch.randint(0, length - mask_size + 1, (1,)).item()
    x[start:start+mask_size] = mask_value
    return x

def spatial_mask(x, size_range=(0.2, 0.4), mask_value=float('nan')):
    num_landmarks = x.size(1)
    num_to_mask = int(num_landmarks * torch.empty(1).uniform_(*size_range).item())
    indices_to_mask = torch.randperm(num_landmarks)[:num_to_mask]
    x[:, indices_to_mask, :] = mask_value
    return x

In [61]:
augmentations = [flip_lr, temporal_mask, resample, spatial_random_affine, temporal_crop, spatial_mask]

In [62]:
def filter_nans_torch(tensor):
    # Assumes tensor shape [n_frames, n_landmarks, 3]
    # Check for any NaN values across the xyz dimensions of each landmark
    is_not_nan = ~torch.isnan(tensor).any(dim=2)  # Check along the xyz dimension
    valid_frames = is_not_nan.all(dim=1)  # Check that all landmarks in a frame are not NaN
    return tensor[valid_frames]

class ParquetDataset(Dataset):
    def __init__(self, parquet_paths, id_labels, augmentations=None):
        self.parquet_paths = parquet_paths
        self.id_labels = id_labels
        self.augmentations = augmentations

    def __len__(self):
        return len(self.parquet_paths)

    def __getitem__(self, idx):
        parquet_path = self.parquet_paths[idx]
        label = self.id_labels[idx]
        data = pq.read_table(parquet_path).to_pandas()

        # Filter relevant landmarks
        filtered_data = data[data['landmark_index'].isin(POINT_LANDMARKS)]

        # Extract the columns and group by frame
        grouped = filtered_data.groupby('frame')

        frames_list = []
        x_list = []
        y_list = []
        z_list = []

        for frame, group in grouped:
            frames_list.append(frame)
            x_list.append(group['x'].values)
            y_list.append(group['y'].values)
            z_list.append(group['z'].values)
        
        # Convert to a single combined numpy array
        combined_array = np.zeros((len(x_list), NUM_NODES, 3))
        for i in range(len(x_list)):
            combined_array[i, :len(x_list[i]), 0] = x_list[i]
            combined_array[i, :len(y_list[i]), 1] = y_list[i]
            combined_array[i, :len(z_list[i]), 2] = z_list[i]

        # Convert to tensor before filtering NaNs
        combined_tensor = torch.tensor(combined_array, dtype=torch.float32)
        filtered_tensor = filter_nans_torch(combined_tensor)

        if self.augmentations:
            for aug in self.augmentations:
                filtered_tensor = aug(filtered_tensor)

        filtered_tensor = filter_nans_torch(combined_tensor)

        # Ensure the filtered array is padded to MAX_LEN
        if filtered_tensor.shape[0] < MAX_LEN:
            padded_tensor = torch.full((MAX_LEN, NUM_NODES, 3), PAD, dtype=torch.float32)
            padded_tensor[:filtered_tensor.shape[0], :, :] = filtered_tensor
        else:
            padded_tensor = filtered_tensor[:MAX_LEN, :, :]


        return padded_tensor, label

def collate_fn(batch):
    data, labels = zip(*batch)
    data = torch.stack(data)  # Ensure all data tensors have the same size
    labels = torch.tensor(labels)
    return data, labels


In [63]:
batch_size = 64
dataset = ParquetDataset(paths, ids, augmentations)
testloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

for i, batch in enumerate(testloader):
    data, labels = batch
    if torch.isnan(data).any():
        print(f"NaN values found in batch {i+1}")
    else:
        print(f"No NaN values in batch {i+1}")
    print(f"Batch {i+1}")
    print(f"Data shape: {data.shape}")
    if i == 4:  # Print only the first 5 batches
        break

No NaN values in batch 1
Batch 1
Data shape: torch.Size([64, 384, 118, 3])
No NaN values in batch 2
Batch 2
Data shape: torch.Size([64, 384, 118, 3])
No NaN values in batch 3
Batch 3
Data shape: torch.Size([64, 384, 118, 3])
No NaN values in batch 4
Batch 4
Data shape: torch.Size([64, 384, 118, 3])
No NaN values in batch 5
Batch 5
Data shape: torch.Size([64, 384, 118, 3])


In [67]:
class ECA(nn.Module):
    def __init__(self, kernel_size=5):
        super(ECA, self).__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=kernel_size//2, bias=False)

    def forward(self, x):
        B, C, T = x.shape  # Batch size, Channels, Time steps
        nn = F.adaptive_avg_pool1d(x, 1).view(B, C, 1)
        nn = self.conv(nn.transpose(1, 2)).transpose(1, 2)  # Transpose to match Conv1d expected input
        nn = torch.sigmoid(nn)
        return x * nn

class LateDropout(nn.Module):
    def __init__(self, rate, start_step=0):
        super(LateDropout, self).__init__()
        self.rate = rate
        self.start_step = start_step
        self.step = 0

    def forward(self, x, training=False):
        if training and self.step >= self.start_step:
            x = F.dropout(x, p=self.rate, training=training)
        if training:
            self.step += 1
        return x

class CausalDWConv1D(nn.Module):
    def __init__(self, in_channels, kernel_size=17, dilation_rate=1, use_bias=False):
        super(CausalDWConv1D, self).__init__()
        self.causal_pad = nn.ConstantPad1d((dilation_rate * (kernel_size - 1), 0), 0)
        self.dw_conv = nn.Conv1d(in_channels, in_channels, kernel_size, dilation=dilation_rate, bias=use_bias, groups=in_channels)

    def forward(self, x):
        x = self.causal_pad(x)
        x = self.dw_conv(x)
        return x

import torch.nn.functional as F  

class Conv1DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation_rate=1, drop_rate=0.0, expand_ratio=2, activation='relu'):
        super(Conv1DBlock, self).__init__()
        self.expand = nn.Conv1d(in_channels, in_channels * expand_ratio, kernel_size=1, bias=True)
        self.dw_conv = CausalDWConv1D(in_channels * expand_ratio, kernel_size, dilation_rate)
        self.bn = nn.BatchNorm1d(in_channels * expand_ratio)
        self.eca = ECA(kernel_size)
        self.project = nn.Conv1d(in_channels * expand_ratio, out_channels, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop_rate)
        self.activation = getattr(F, activation)

    def forward(self, x):
        residual = x
        x = self.activation(self.expand(x))
        x = self.dw_conv(x)
        x = self.bn(x)
        x = self.eca(x)
        x = self.project(x)
        if self.drop.p > 0:
            x = self.drop(x)
        if residual.shape == x.shape:
            x += residual
        return x

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim=256, num_heads=4, dropout=0):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.scale = dim ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim=256, num_heads=4, expand=4, attn_dropout=0.2, drop_rate=0.2, activation='relu'):
        super(TransformerBlock, self).__init__()
        self.attn = MultiHeadSelfAttention(dim, num_heads, attn_dropout)
        self.drop1 = nn.Dropout(drop_rate)
        self.norm1 = nn.LayerNorm(dim)  # Using LayerNorm instead of BatchNorm
        self.expand = nn.Linear(dim, dim * expand, bias=False)
        self.activation = getattr(F, activation)
        self.project = nn.Linear(dim * expand, dim, bias=False)
        self.drop2 = nn.Dropout(drop_rate)
        self.norm2 = nn.LayerNorm(dim)  # Using LayerNorm instead of BatchNorm

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop1(attn_out)
        residual = x
        x = self.activation(self.expand(self.norm2(x)))
        x = self.project(x)
        x = residual + self.drop2(x)
        return x

In [68]:
class Model(nn.Module):
    def __init__(self, num_frames=384, num_landmarks=118, num_classes=250, dim=192, dropout_step=0):
        super(Model, self).__init__()
        self.masking_value = -100.0
        self.embedding = nn.Linear(num_landmarks * 3, dim, bias=False)
        self.bn = nn.BatchNorm1d(dim)
        self.conv1 = Conv1DBlock(dim, dim, 17, drop_rate=0.2)
        self.conv2 = Conv1DBlock(dim, dim, 17, drop_rate=0.2)
        self.conv3 = Conv1DBlock(dim, dim, 17, drop_rate=0.2)
        self.trans1 = TransformerBlock(dim, expand=2)
        self.conv4 = Conv1DBlock(dim, dim, 17, drop_rate=0.2)
        self.conv5 = Conv1DBlock(dim, dim, 17, drop_rate=0.2)
        self.conv6 = Conv1DBlock(dim, dim, 17, drop_rate=0.2)
        self.trans2 = TransformerBlock(dim, expand=2)
        self.fc = nn.Linear(dim, num_classes)
        self.late_dropout = LateDropout(0.8, start_step=dropout_step)
        
    def forward(self, x, training=False):
        B, T, L, C = x.shape  # [batch_size, num_frames, num_landmarks, 3]
        x = x.reshape(B, T, L * C)  # Flatten landmarks and coordinates into the feature dimension
        mask = (x != self.masking_value).float()
        x = self.embedding(x)
        x = self.bn(x.transpose(1, 2)).transpose(1, 2)  # Apply batch norm on the feature dimension
        x = x.transpose(1, 2)  # [batch_size, dim, num_frames] for Conv1D
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.trans1(x.transpose(1, 2)).transpose(1, 2)  # [batch_size, num_frames, dim]
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.trans2(x.transpose(1, 2)).transpose(1, 2)  # [batch_size, num_frames, dim]
        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)  # [batch_size, dim]
        x = self.late_dropout(x, training=training)
        x = self.fc(x)
        return x

In [70]:
class CFG:
    n_splits = 5
    save_output = True
    output_dir = '.'

    seed = 42
    verbose = 2 #0) silent 1) progress bar 2) one line per epoch

    max_len = 384
    replicas = 8
    lr = 5e-4 * replicas
    weight_decay = 0.1
    lr_min = 1e-6
    epochs = 300 #400
    warmup = 0
    batch_size = 64 * replicas
    snapshot_epochs = []
    swa_epochs = [] #list(range(epoch//2,epoch+1))

    fp16 = True
    fgm = False
    awp = True
    awp_lambda = 0.2
    awp_start_epoch = 15
    dropout_start_epoch = 15
    resume = 0
    decay_type = 'cosine'
    dim = 192
    comment = f'islr-fp16-192-8-seed{seed}'

def get_dataloaders(dataset, num_folds, batch_size, test_size_ratio=0.1):
    indices = np.arange(len(dataset))
    test_size = int(test_size_ratio * len(dataset))
    test_indices = np.random.choice(indices, test_size, replace=False)

    train_val_indices = np.setdiff1d(indices, test_indices)
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

    fold_data = []

    for train_idx, val_idx in kf.split(train_val_indices):
        train_idx = train_val_indices[train_idx]
        val_idx = train_val_indices[val_idx]
        
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
        
        fold_data.append((train_loader, val_loader))
    
    test_subset = Subset(dataset, test_indices)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)
    
    return fold_data, test_loader

# Parameters
num_folds = 5
batch_size = 64 * 8

# Split the dataset indices for k-fold cross-validation
indices = np.arange(len(dataset))
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

fold_data, test_loader = get_dataloaders(dataset, num_folds, batch_size)

for fold, (train_loader, val_loader) in enumerate(fold_data):
    print(f"Fold {fold + 1}:")
    print(f"Train loader has {len(train_loader.dataset)} samples")
    print(f"Val loader has {len(val_loader.dataset)} samples")
print(f"Test loader has {len(test_loader.dataset)} samples")

def train_fold(cfg, fold, train_loader, val_loader, device='cuda'):
    model = Model(cfg.dim).to(device)
    if cfg.fp16:
        scaler = amp.GradScaler()

    optimizer = Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=cfg.lr, epochs=cfg.epochs, steps_per_epoch=len(train_loader))

    best_loss = float('inf')
    for epoch in range(cfg.epochs):
        model.train()
        running_loss = 0.0
        iteration = 0
        print(f"Starting training for epoch {epoch+1}/{cfg.epochs}, fold {fold+1}")
        
        for data, target in train_loader:
            iteration += 1
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            
            with amp.autocast(enabled=cfg.fp16):
                output = model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
            
            if cfg.fp16:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            if iteration % 100 == 0:
                print(f"    Epoch {epoch+1}, Fold {fold+1}, Iteration {iteration}, Partial Loss: {running_loss / iteration:.4f}")

        scheduler.step()
        print(f"Training completed for epoch {epoch+1}/{cfg.epochs}, fold {fold+1}. Average Loss: {running_loss / len(train_loader):.4f}")
        
        # Validation
        model.eval()
        validation_loss = 0.0
        val_iteration = 0
        print(f"Starting validation for epoch {epoch+1}/{cfg.epochs}, fold {fold+1}")
        
        with torch.no_grad():
            for data, target in val_loader:
                val_iteration += 1
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                validation_loss += loss.item()
                
                if val_iteration % 100 == 0:
                    print(f"    Epoch {epoch+1}, Fold {fold+1}, Validation Iteration {val_iteration}, Partial Val Loss: {validation_loss / val_iteration:.4f}")

        avg_val_loss = validation_loss / len(val_loader)
        print(f"Validation completed for epoch {epoch+1}/{cfg.epochs}, fold {fold+1}. Average Validation Loss: {avg_val_loss:.4f}")

        # Save the best model
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(model.state_dict(), f'{cfg.output_dir}/{cfg.comment}-fold{fold}-best.pth')
            print(f"New best model saved with validation loss {avg_val_loss:.4f}")

    return model


def train_folds(cfg):
    # Assuming `get_dataloaders` returns all the loaders including the test loader
    fold_data, test_loader = get_dataloaders(dataset, cfg.n_splits, cfg.batch_size)

    for fold, (train_loader, val_loader) in enumerate(fold_data):
        print(f"Training fold {fold + 1}")
        # Assuming `train_fold` is adapted to take loaders directly
        model = train_fold(cfg, fold, train_loader, val_loader)
        print(f"Completed training for fold {fold + 1}")
    
    return model
        
cfg = CFG()
model = train_folds(cfg)

Fold 1:
Train loader has 68024 samples
Val loader has 17006 samples
Fold 2:
Train loader has 68024 samples
Val loader has 17006 samples
Fold 3:
Train loader has 68024 samples
Val loader has 17006 samples
Fold 4:
Train loader has 68024 samples
Val loader has 17006 samples
Fold 5:
Train loader has 68024 samples
Val loader has 17006 samples
Test loader has 9447 samples
Training fold 1


AssertionError: Torch not compiled with CUDA enabled

In [None]:
def evaluate(model, test_loader, criterion, cfg, device='cuda'):
    model.eval()
    model.to(device)
    running_loss = 0.0
    correct_predictions = 0

    i = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            print('-' * 50)
            print('test iteration', i)
            i += 1
            print('inputs ', inputs.shape)
            print('outputs ', outputs.shape)
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct_predictions += torch.sum(preds == labels.data).float()  # Ensure float32 for correct accumulation

    test_loss = running_loss / len(test_loader.dataset)
    accuracy = correct_predictions / len(test_loader.dataset)
    return test_loss, accuracy

criterion = torch.nn.CrossEntropyLoss()

print('=' * 50)
test_loss, test_accuracy = evaluate(model, test_loader, criterion, cfg)
print('=' * 50)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print('=' * 50)
