In [None]:
import os
import json
import random
import logging
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from transformers import (
    AutoTokenizer, 
    AutoModel, 
    AutoConfig,
    ViTImageProcessor, 
    ViTModel,
    get_linear_schedule_with_warmup
)

# Notebook-wide logger: INFO and above goes to a timestamped log file and to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

log_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Handlers are attached only when the logger has none, so re-running this
# cell in the same kernel does not produce duplicated log lines.
if logger.handlers:
    logger.info("Logger already initialized with handlers.")
else:
    log_file_name = f"vqa_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    log_file_path = f"/kaggle/working/{log_file_name}"

    file_handler = logging.FileHandler(log_file_path)
    stream_handler = logging.StreamHandler()

    # Both handlers share the same level and format.
    for handler in (file_handler, stream_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(log_formatter)
        logger.addHandler(handler)

    logger.info(f"Logging initialized. Log file will be at: {log_file_path}")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), sets the ``PYTHONHASHSEED`` environment variable and
    switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

# Pick the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# Central configuration read by the data loading, model and training code below.
CONFIG = dict(
    data_dir='/kaggle/input/vizwiz/',            # root folder holding annotations and images
    image_size=384,                              # side length of the black fallback image
    batch_size=16,                               # DataLoader batch size (all splits)
    num_workers=4,                               # DataLoader worker processes
    learning_rate=2e-5,                          # AdamW learning rate
    weight_decay=0.01,                           # AdamW weight decay
    num_epochs=10,                               # number of training epochs
    warmup_ratio=0.1,                            # fraction of total steps used for LR warmup
    max_grad_norm=1.0,                           # gradient clipping threshold
    text_model='bert-base-uncased',              # question encoder / tokenizer checkpoint
    vision_model='google/vit-base-patch16-384',  # image encoder / processor checkpoint
    hidden_size=768,                             # width of the projection and fusion layers
    dropout=0.1,                                 # dropout probability in fusion and heads
    save_dir='./models',                         # where checkpoints and plots are written
    seed=42,                                     # NOTE(review): not read by the visible code; set_seed is called with a literal
)

# Make sure the checkpoint directory is there before anything is saved to it.
os.makedirs(CONFIG['save_dir'], exist_ok=True)

In [None]:
# VizWiz Dataset
class VizWizDataset(Dataset):
    """VizWiz visual-question-answering samples for answer classification.

    Each item pairs one image with one question and a single target index
    taken from a fixed answer vocabulary.

    For ``split == 'train'`` the vocabulary is built from the annotations and
    written to ``answer_vocab.json`` in the current working directory. Any
    other split reads that file back, and only builds a vocabulary from its
    own annotations (with a warning) when the file is missing.

    The vocabulary is a dict with two mappings, ``answer_to_idx`` and
    ``idx_to_answer``. ``idx_to_answer`` has ``int`` keys when built in
    memory and ``str`` keys once it has been round-tripped through JSON.

    Args:
        annotations_file: Path to the JSON annotation list for the split.
        img_dir: Directory containing the images named in the annotations.
        processor: Callable image processor, invoked as
            ``processor(images=..., return_tensors="pt")``.
        tokenizer: Callable tokenizer applied to the question text.
        max_length: Padded / truncated question length in tokens.
        split: Split name; ``'train'`` triggers vocabulary creation.
    """

    def __init__(self, annotations_file, img_dir, processor, tokenizer, max_length=128, split='train'):
        # Context manager so the annotation file handle is closed right away.
        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)
        self.img_dir = img_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        # Create answer vocabulary from training data
        if split == 'train':
            self.answer_vocab = self._create_answer_vocab()
            # Save vocab for inference
            with open('answer_vocab.json', 'w') as f:
                json.dump(self.answer_vocab, f)
        else:
            # Load vocab for val/test
            try:
                with open('answer_vocab.json', 'r') as f:
                    self.answer_vocab = json.load(f)
            except FileNotFoundError:
                logger.warning("Answer vocabulary not found, creating from current split (not recommended)")
                self.answer_vocab = self._create_answer_vocab()

        logger.info(f"Loaded {len(self.annotations)} samples for {split} split")
        logger.info(f"Answer vocabulary size: {len(self.answer_vocab)}")

    @staticmethod
    def _majority_confident_answer(answers):
        """Return the most frequent lower-cased answer marked confidence 'yes', or None."""
        from collections import Counter

        confident = [a['answer'].lower() for a in answers if a['answer_confidence'] == 'yes']
        if not confident:
            return None
        return Counter(confident).most_common(1)[0][0]

    def _create_answer_vocab(self):
        """Create a vocabulary of answers and map them to indices.

        The special answers ``'unanswerable'`` and ``'unknown'`` always take
        indices 0 and 1. The remaining entries are the per-question majority
        answers, de-duplicated and sorted.

        Special answers are excluded from the collected set. Otherwise a
        collected answer equal to one of them would appear twice, leaving
        ``answer_to_idx`` with fewer entries than its largest index and
        producing class targets outside the classifier's output range.
        Sorting makes the index assignment reproducible between runs.

        Returns:
            dict with ``'answer_to_idx'`` (str -> int) and
            ``'idx_to_answer'`` (int -> str); both have the same length.
        """
        special_tokens = ['unanswerable', 'unknown']

        collected = set()
        for item in self.annotations:
            most_common = self._majority_confident_answer(item.get('answers') or [])
            if most_common is not None:
                collected.add(most_common)

        unique_answers = special_tokens + sorted(collected.difference(special_tokens))

        # Create vocab dictionary
        answer_to_idx = {ans: idx for idx, ans in enumerate(unique_answers)}
        idx_to_answer = {idx: ans for idx, ans in enumerate(unique_answers)}

        return {
            'answer_to_idx': answer_to_idx,
            'idx_to_answer': idx_to_answer
        }

    def __len__(self):
        """Number of annotated samples in this split."""
        return len(self.annotations)

    def _encode_image(self, image):
        """Run the image processor and drop the leading batch dimension of each tensor."""
        encoding = self.processor(images=image, return_tensors="pt")
        return {k: v.squeeze(0) for k, v in encoding.items()}

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            dict with ``image_encoding`` and ``question_encoding`` (dicts of
            tensors without a batch dimension), ``answer_idx`` and
            ``answerable`` (scalar long tensors), ``image_id`` and
            ``question`` (raw strings).
        """
        item = self.annotations[idx]

        # Load and preprocess image
        image_path = os.path.join(self.img_dir, item['image'])
        try:
            # ``with`` closes the underlying file once the RGB copy is encoded.
            with Image.open(image_path) as raw_image:
                image_encoding = self._encode_image(raw_image.convert('RGB'))
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            # Use a black image as fallback
            fallback = Image.new('RGB', (CONFIG['image_size'], CONFIG['image_size']), color=0)
            image_encoding = self._encode_image(fallback)

        # Process question
        question = item['question']
        question_encoding = self.tokenizer(
            question,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        question_encoding = {k: v.squeeze(0) for k, v in question_encoding.items()}

        # Process answer.
        # Annotations without labels (presumably an unlabeled test split --
        # TODO confirm against the data files) carry no 'answerable' /
        # 'answers' keys. Placeholder targets are used instead of raising
        # KeyError; metrics computed on such a split are not meaningful.
        answer_to_idx = self.answer_vocab['answer_to_idx']
        answerable = item.get('answerable', 1)
        if answerable == 0:
            answer_idx = answer_to_idx['unanswerable']
        else:
            # Most frequent answer with 'yes' confidence; anything missing
            # from the vocabulary maps to 'unknown'.
            most_common = self._majority_confident_answer(item.get('answers') or [])
            if most_common is None:
                answer_idx = answer_to_idx['unknown']
            else:
                answer_idx = answer_to_idx.get(most_common, answer_to_idx['unknown'])

        return {
            'image_encoding': image_encoding,
            'question_encoding': question_encoding,
            'answer_idx': torch.tensor(answer_idx, dtype=torch.long),
            'answerable': torch.tensor(answerable, dtype=torch.long),
            'image_id': item['image'],
            'question': question
        }

# Data collator
def collate_fn(batch):
    """Merge a list of dataset samples into one batch dictionary.

    Tensor fields are stacked along a new leading batch dimension; the image
    ids and question strings are gathered into plain Python lists.

    Args:
        batch: List of sample dicts as produced by the dataset's
            ``__getitem__``.

    Returns:
        dict with ``image_encodings`` (``pixel_values``),
        ``question_encodings`` (``input_ids``, ``attention_mask``),
        ``answer_idxs``, ``answerable``, ``image_ids`` and ``questions``.
    """
    def stack(pick):
        # Stack the tensor that ``pick`` extracts from every sample.
        return torch.stack([pick(sample) for sample in batch])

    pixel_values = stack(lambda s: s['image_encoding']['pixel_values'])
    input_ids = stack(lambda s: s['question_encoding']['input_ids'])
    attention_mask = stack(lambda s: s['question_encoding']['attention_mask'])
    answer_idxs = stack(lambda s: s['answer_idx'])
    answerable = stack(lambda s: s['answerable'])

    return {
        'image_encodings': {'pixel_values': pixel_values},
        'question_encodings': {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        },
        'answer_idxs': answer_idxs,
        'answerable': answerable,
        'image_ids': [sample['image_id'] for sample in batch],
        'questions': [sample['question'] for sample in batch],
    }

# Load data
def load_data(config):
    """Build the train / val / test data loaders and return the answer vocabulary.

    Args:
        config: Mapping providing ``text_model``, ``vision_model``,
            ``data_dir``, ``batch_size`` and ``num_workers``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, answer_vocab)`` where
        ``answer_vocab`` is the vocabulary held by the training dataset.
    """
    # Preprocessors shared by all three splits.
    tokenizer = AutoTokenizer.from_pretrained(config['text_model'])
    processor = ViTImageProcessor.from_pretrained(config['vision_model'])

    data_root = config['data_dir']

    # The train dataset is created first: it writes the answer vocabulary
    # file that the val and test datasets read back.
    datasets = {}
    for split_name in ('train', 'val', 'test'):
        annotation_path = os.path.join(data_root, f'Annotations/Annotations/{split_name}.json')
        image_dir = os.path.join(data_root, f'{split_name}/{split_name}/')
        datasets[split_name] = VizWizDataset(
            annotation_path, image_dir, processor, tokenizer, split=split_name
        )

    def make_loader(split_name):
        # Only the training split is shuffled.
        return DataLoader(
            datasets[split_name],
            batch_size=config['batch_size'],
            shuffle=(split_name == 'train'),
            num_workers=config['num_workers'],
            pin_memory=True,
            collate_fn=collate_fn
        )

    train_loader = make_loader('train')
    val_loader = make_loader('val')
    test_loader = make_loader('test')

    return train_loader, val_loader, test_loader, datasets['train'].answer_vocab

In [None]:
# VQA Model
class VQAModel(nn.Module):
    """Two-tower VQA classifier: image encoder + text encoder, fused by concatenation.

    The position-0 (CLS) hidden state of each encoder is projected to
    ``config['hidden_size']``, the two projections are concatenated and passed
    through a fusion layer, and two heads read the fused vector: one scores
    the answer vocabulary, the other makes a 2-way answerable decision.

    Args:
        config: Mapping providing ``vision_model``, ``text_model``,
            ``hidden_size`` and ``dropout``.
        num_answers: Number of output classes of the answer head.
    """

    def __init__(self, config, num_answers):
        super().__init__()
        self.config = config
        self.num_answers = num_answers

        width = config['hidden_size']
        drop_prob = config['dropout']

        # Pretrained encoders and their configs (the configs give the encoder output widths).
        self.vision_config = AutoConfig.from_pretrained(config['vision_model'])
        self.vision_encoder = ViTModel.from_pretrained(config['vision_model'])
        self.text_config = AutoConfig.from_pretrained(config['text_model'])
        self.text_encoder = AutoModel.from_pretrained(config['text_model'])

        # Bring both modalities to the same width.
        self.vision_projection = nn.Linear(self.vision_config.hidden_size, width)
        self.text_projection = nn.Linear(self.text_config.hidden_size, width)

        # Fuse the concatenated image and text vectors back down to ``width``.
        self.fusion = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(drop_prob),
        )

        # Answer head over the vocabulary, and a binary answerable head.
        self.classifier = self._make_head(width, width, num_answers, drop_prob)
        self.answerable_classifier = self._make_head(width, width // 2, 2, drop_prob)

    @staticmethod
    def _make_head(in_features, mid_features, out_features, drop_prob):
        """Linear -> LayerNorm -> GELU -> Dropout -> Linear prediction head."""
        return nn.Sequential(
            nn.Linear(in_features, mid_features),
            nn.LayerNorm(mid_features),
            nn.GELU(),
            nn.Dropout(drop_prob),
            nn.Linear(mid_features, out_features),
        )

    @staticmethod
    def _cls_embedding(encoder, projection, encodings):
        """Encode the inputs and project the position-0 (CLS) hidden state."""
        hidden_states = encoder(**encodings).last_hidden_state
        return projection(hidden_states[:, 0])

    def forward(self, image_encodings, question_encodings):
        """Score answers and the answerable decision for a batch.

        Args:
            image_encodings: Keyword arguments for the vision encoder.
            question_encodings: Keyword arguments for the text encoder.

        Returns:
            dict with ``answer_logits``, ``answerable_logits`` and the
            intermediate ``fused_features``.
        """
        vision_embeds = self._cls_embedding(
            self.vision_encoder, self.vision_projection, image_encodings
        )
        text_embeds = self._cls_embedding(
            self.text_encoder, self.text_projection, question_encodings
        )

        fused_features = self.fusion(torch.cat([vision_embeds, text_embeds], dim=1))

        return {
            'answer_logits': self.classifier(fused_features),
            'answerable_logits': self.answerable_classifier(fused_features),
            'fused_features': fused_features
        }

In [None]:
# Training and evaluation functions
def train_one_epoch(model, data_loader, optimizer, scheduler, criterion, device, epoch, config):
    """Run one optimisation pass over ``data_loader``.

    For every batch the answer loss and the answerable loss are summed,
    back-propagated with gradient clipping at ``config['max_grad_norm']``,
    and followed by one optimizer step and one scheduler step.

    Args:
        model: Network called as ``model(image_encodings, question_encodings)``.
        data_loader: Iterable of batches shaped like ``collate_fn`` output.
        optimizer: Optimizer updating the model parameters.
        scheduler: Learning-rate scheduler, stepped once per batch.
        criterion: Loss applied to both heads.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (progress bar label only).
        config: Mapping providing ``num_epochs`` and ``max_grad_norm``.

    Returns:
        dict with the mean batch ``loss``, ``answer_accuracy`` and
        ``answerable_accuracy`` over the epoch.
    """
    model.train()

    def on_device(tensors):
        return {name: tensor.to(device) for name, tensor in tensors.items()}

    losses = []
    predicted = {'answer': [], 'answerable': []}
    expected = {'answer': [], 'answerable': []}

    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]")

    for batch in progress_bar:
        image_encodings = on_device(batch['image_encodings'])
        question_encodings = on_device(batch['question_encodings'])
        targets = {
            'answer': batch['answer_idxs'].to(device),
            'answerable': batch['answerable'].to(device),
        }

        outputs = model(image_encodings, question_encodings)
        logits = {
            'answer': outputs['answer_logits'],
            'answerable': outputs['answerable_logits'],
        }

        # Combined loss: answer head + answerable head, equally weighted.
        loss = criterion(logits['answer'], targets['answer']) \
            + criterion(logits['answerable'], targets['answerable'])

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        optimizer.step()
        scheduler.step()

        # Keep per-sample predictions and labels for the epoch-level accuracy.
        for head in ('answer', 'answerable'):
            predicted[head].extend(torch.argmax(logits[head], dim=1).detach().cpu().tolist())
            expected[head].extend(targets[head].detach().cpu().tolist())

        losses.append(loss.item())
        progress_bar.set_postfix({
            'loss': f"{sum(losses) / len(losses):.4f}",
        })

    def accuracy(head):
        return np.mean(np.array(predicted[head]) == np.array(expected[head]))

    return {
        'loss': np.mean(losses),
        'answer_accuracy': accuracy('answer'),
        'answerable_accuracy': accuracy('answerable')
    }

def evaluate(model, data_loader, criterion, device, epoch, config, split='Val'):
    """Score the model on ``data_loader`` without updating any weights.

    The model is put in eval mode and every batch is run under
    ``torch.no_grad()``. The per-batch loss is the sum of the answer loss and
    the answerable loss.

    Args:
        model: Network called as ``model(image_encodings, question_encodings)``.
        data_loader: Iterable of batches shaped like ``collate_fn`` output.
        criterion: Loss applied to both heads.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (progress bar label only).
        config: Mapping providing ``num_epochs``.
        split: Label shown in the progress bar.

    Returns:
        dict with the mean batch ``loss``, ``answer_accuracy`` and
        ``answerable_accuracy``.
    """
    model.eval()

    def on_device(tensors):
        return {name: tensor.to(device) for name, tensor in tensors.items()}

    losses = []
    predicted = {'answer': [], 'answerable': []}
    expected = {'answer': [], 'answerable': []}

    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [{split}]")

    with torch.no_grad():
        for batch in progress_bar:
            image_encodings = on_device(batch['image_encodings'])
            question_encodings = on_device(batch['question_encodings'])
            targets = {
                'answer': batch['answer_idxs'].to(device),
                'answerable': batch['answerable'].to(device),
            }

            outputs = model(image_encodings, question_encodings)
            logits = {
                'answer': outputs['answer_logits'],
                'answerable': outputs['answerable_logits'],
            }

            # Combined loss: answer head + answerable head, equally weighted.
            loss = criterion(logits['answer'], targets['answer']) \
                + criterion(logits['answerable'], targets['answerable'])

            # Keep per-sample predictions and labels for the accuracy figures.
            for head in ('answer', 'answerable'):
                predicted[head].extend(torch.argmax(logits[head], dim=1).detach().cpu().tolist())
                expected[head].extend(targets[head].detach().cpu().tolist())

            losses.append(loss.item())
            progress_bar.set_postfix({
                'loss': f"{sum(losses) / len(losses):.4f}",
            })

    def accuracy(head):
        return np.mean(np.array(predicted[head]) == np.array(expected[head]))

    return {
        'loss': np.mean(losses),
        'answer_accuracy': accuracy('answer'),
        'answerable_accuracy': accuracy('answerable')
    }

In [None]:
# Main training function
def train_model(config):
    """Train the VQA model end to end and evaluate the best checkpoint on the test split.

    Steps: build the data loaders, create the model, AdamW optimizer and a
    linear warmup/decay schedule, then for every epoch train, validate, log
    and checkpoint. The checkpoint with the highest validation answer accuracy
    is written to ``vqa_model_best.pt``; a per-epoch checkpoint is written as
    well. After training the metric curves are saved as
    ``training_history.png`` and the best weights are reloaded for a final
    pass over the test loader.

    Args:
        config: Mapping providing ``num_epochs``, ``learning_rate``,
            ``weight_decay``, ``warmup_ratio``, ``save_dir`` and the keys read
            by ``load_data``, ``VQAModel`` and ``train_one_epoch``.

    Returns:
        Tuple ``(model, answer_vocab)``. ``model`` holds the best-validation
        weights whenever a best checkpoint was written during this run.
    """
    logger.info("Starting model training...")
    logger.info(f"Configuration: {config}")

    # Load data
    train_loader, val_loader, test_loader, answer_vocab = load_data(config)
    num_answers = len(answer_vocab['answer_to_idx'])
    logger.info(f"Number of answers in vocabulary: {num_answers}")

    # Initialize model
    model = VQAModel(config, num_answers)
    model.to(device)

    # Calculate trainable parameters
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model has {param_count:,} trainable parameters")

    # Initialize optimizer and scheduler
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # The scheduler is stepped once per batch, so the schedule spans all batches of all epochs.
    total_steps = len(train_loader) * config['num_epochs']
    warmup_steps = int(total_steps * config['warmup_ratio'])

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Start below every reachable accuracy so the first epoch always produces
    # a best checkpoint, even when validation accuracy is exactly 0.
    best_val_accuracy = float('-inf')
    best_model_path = os.path.join(config['save_dir'], f"vqa_model_best.pt")
    best_model_saved = False

    metrics_history = {
        'train_loss': [],
        'train_answer_acc': [],
        'train_answerable_acc': [],
        'val_loss': [],
        'val_answer_acc': [],
        'val_answerable_acc': []
    }

    for epoch in range(config['num_epochs']):
        # Train
        train_metrics = train_one_epoch(
            model, train_loader, optimizer, scheduler, criterion, device, epoch, config
        )

        # Validate
        val_metrics = evaluate(
            model, val_loader, criterion, device, epoch, config
        )

        # Log metrics
        logger.info(f"Epoch {epoch+1}/{config['num_epochs']}")
        logger.info(f"Train Loss: {train_metrics['loss']:.4f}, "
                   f"Answer Acc: {train_metrics['answer_accuracy']:.4f}, "
                   f"Answerable Acc: {train_metrics['answerable_accuracy']:.4f}")
        logger.info(f"Val Loss: {val_metrics['loss']:.4f}, "
                   f"Answer Acc: {val_metrics['answer_accuracy']:.4f}, "
                   f"Answerable Acc: {val_metrics['answerable_accuracy']:.4f}")

        # Update metrics history
        metrics_history['train_loss'].append(train_metrics['loss'])
        metrics_history['train_answer_acc'].append(train_metrics['answer_accuracy'])
        metrics_history['train_answerable_acc'].append(train_metrics['answerable_accuracy'])
        metrics_history['val_loss'].append(val_metrics['loss'])
        metrics_history['val_answer_acc'].append(val_metrics['answer_accuracy'])
        metrics_history['val_answerable_acc'].append(val_metrics['answerable_accuracy'])

        # Stored as a plain Python float: a NumPy scalar inside the checkpoint
        # may be rejected by the restricted unpickler that recent torch.load
        # versions use by default.
        val_accuracy = float(val_metrics['answer_accuracy'])

        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_accuracy': val_accuracy,
            'config': config,
            'answer_vocab': answer_vocab
        }

        # Save model if it's the best so far
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(checkpoint_state, best_model_path)
            best_model_saved = True
            logger.info(f"Saved best model with val accuracy: {val_accuracy:.4f}")

        # Save checkpoint every epoch
        checkpoint_path = os.path.join(config['save_dir'], f"vqa_model_epoch_{epoch+1}.pt")
        torch.save(checkpoint_state, checkpoint_path)

    # Plot training history
    _plot_training_history(
        metrics_history, os.path.join(config['save_dir'], 'training_history.png')
    )

    # Evaluate on test set
    logger.info("Evaluating on test set...")

    if best_model_saved:
        # Only reload a checkpoint written by this run; a file left over from
        # an earlier run may not match this model.
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        logger.warning("No best checkpoint was saved during this run; "
                       "evaluating the current model weights.")

    test_metrics = evaluate(
        model, test_loader, criterion, device, -1, config, split='Test'
    )

    logger.info(f"Test Loss: {test_metrics['loss']:.4f}, "
               f"Answer Acc: {test_metrics['answer_accuracy']:.4f}, "
               f"Answerable Acc: {test_metrics['answerable_accuracy']:.4f}")

    return model, answer_vocab


def _plot_training_history(metrics_history, save_path):
    """Draw the loss and accuracy curves on a 2x2 grid and save the figure to ``save_path``."""
    # (title, y-axis label, train series key, validation series key)
    panels = [
        ('Loss', 'Loss', 'train_loss', 'val_loss'),
        ('Answer Accuracy', 'Accuracy', 'train_answer_acc', 'val_answer_acc'),
        ('Answerable Accuracy', 'Accuracy', 'train_answerable_acc', 'val_answerable_acc'),
    ]

    plt.figure(figsize=(12, 8))
    for position, (title, ylabel, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(metrics_history[train_key], label='Train')
        plt.plot(metrics_history[val_key], label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.savefig(save_path)

In [None]:
# Inference function
def predict(model, image_path, question, processor, tokenizer, answer_vocab, device, config):
    """Answer a single question about a single image.

    Args:
        model: Trained network called as ``model(image_encoding, question_encoding)``.
        image_path: Path of the image file to load.
        question: Question text.
        processor: Callable image processor, invoked as
            ``processor(images=..., return_tensors="pt")``.
        tokenizer: Callable tokenizer applied to the question.
        answer_vocab: Vocabulary dict with an ``'idx_to_answer'`` mapping.
            Keys may be ``int`` (vocabulary built in memory or restored from a
            checkpoint) or ``str`` (vocabulary read back from JSON).
        device: Device the inputs are moved to.
        config: Accepted for interface compatibility; not read here.

    Returns:
        dict with ``answer`` (str), ``answer_confidence`` (softmax probability
        of the chosen answer), ``is_answerable`` (bool) and
        ``answerable_confidence`` (softmax probability of that decision).

    Raises:
        KeyError: If the predicted index is missing from ``idx_to_answer``
            under both its ``int`` and ``str`` form.
    """
    model.eval()

    # Preprocess image; ``with`` closes the file once the RGB copy is encoded.
    with Image.open(image_path) as raw_image:
        image_encoding = processor(images=raw_image.convert('RGB'), return_tensors="pt")
    image_encoding = {k: v.to(device) for k, v in image_encoding.items()}

    # Preprocess question
    question_encoding = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    question_encoding = {k: v.to(device) for k, v in question_encoding.items()}

    # Get predictions
    with torch.no_grad():
        outputs = model(image_encoding, question_encoding)

        answer_logits = outputs['answer_logits']
        answerable_logits = outputs['answerable_logits']

        answer_idx = torch.argmax(answer_logits, dim=1).item()
        answerable_idx = torch.argmax(answerable_logits, dim=1).item()

        # JSON turns the integer keys of idx_to_answer into strings, while the
        # in-memory vocabulary keeps integers, so both forms are accepted.
        idx_to_answer = answer_vocab['idx_to_answer']
        if answer_idx in idx_to_answer:
            answer = idx_to_answer[answer_idx]
        else:
            answer = idx_to_answer[str(answer_idx)]
        is_answerable = bool(answerable_idx)

        # Get confidence scores
        answer_probs = torch.softmax(answer_logits, dim=1)[0]
        answerable_probs = torch.softmax(answerable_logits, dim=1)[0]

        answer_confidence = answer_probs[answer_idx].item()
        answerable_confidence = answerable_probs[answerable_idx].item()

    return {
        'answer': answer,
        'answer_confidence': answer_confidence,
        'is_answerable': is_answerable,
        'answerable_confidence': answerable_confidence
    }

# Visualization function for demo
def visualize_prediction(image_path, question, prediction):
    """Show an image with the question, predicted answer and confidences as its title.

    Args:
        image_path: Path of the image to display.
        question: Question text shown on the first title line.
        prediction: dict with ``answer``, ``answer_confidence``,
            ``is_answerable`` and ``answerable_confidence``.
    """
    plt.figure(figsize=(10, 8))

    # Image fills the figure; axes are hidden.
    plt.imshow(Image.open(image_path).convert('RGB'))
    plt.axis('off')

    if prediction['is_answerable']:
        is_answerable = "Yes"
    else:
        is_answerable = "No"

    caption_lines = [
        f"Q: {question}",
        f"A: {prediction['answer']} (Confidence: {prediction['answer_confidence']:.2f})",
        f"Answerable: {is_answerable} (Confidence: {prediction['answerable_confidence']:.2f})",
    ]
    plt.title("\n".join(caption_lines))

    plt.tight_layout()
    plt.show()

In [None]:
# Run the entire training pipeline
if __name__ == "__main__":
    # Full pipeline: data loading, training, validation and test evaluation.
    model, answer_vocab = train_model(CONFIG)

    # Persist the returned model together with what is needed to rebuild it.
    # train_model reloads the best-validation checkpoint into `model` before
    # returning, so this file holds those weights rather than the last-epoch
    # ones.
    final_model_path = os.path.join(CONFIG['save_dir'], "vqa_model_final.pt")
    final_checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': CONFIG,
        'answer_vocab': answer_vocab
    }
    torch.save(final_checkpoint, final_model_path)
    logger.info(f"Saved final model to {final_model_path}")

    # Example (left disabled): loading the saved model for inference.
    #
    # logger.info("Loading model for inference...")
    #
    # # Initialize preprocessors
    # processor = ViTImageProcessor.from_pretrained(CONFIG['vision_model'])
    # tokenizer = AutoTokenizer.from_pretrained(CONFIG['text_model'])
    #
    # # Load model
    # checkpoint = torch.load(final_model_path)
    # loaded_model = VQAModel(CONFIG, len(checkpoint['answer_vocab']['answer_to_idx']))
    # loaded_model.load_state_dict(checkpoint['model_state_dict'])
    # loaded_model.to(device)
    #
    # logger.info("Model loaded successfully!")