In [1]:
!pip install pydicom

import os
import shutil
import numpy as np
import pandas as pd 
import cv2
import pydicom
from PIL import Image
import timm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

import timm
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torchvision.transforms.functional as F



Define directories for DICOM and heatmap images


In [2]:
dicom_dir = '/kaggle/input/dicom-files/dicom files'  
heatmap_dir = '/kaggle/input/heatmap-images/heatmap_images'  
fused_dir = '/kaggle/working/fused_images_kaggle'  

os.makedirs(fused_dir, exist_ok=True)

Function to fuse dicoms and heatmaps

In [5]:
def fuse_images_with_resizing(dicom_image, heatmap_image):
    '''
    Function to fuse dicoms and heatmaps
    '''
    heatmap_resized = cv2.resize(heatmap_image, (dicom_image.shape[1], dicom_image.shape[0]))

    heatmap_resized = (heatmap_resized - np.min(heatmap_resized)) / (np.max(heatmap_resized) - np.min(heatmap_resized))

    fused_image = dicom_image * heatmap_resized
    
    return fused_image

Fusing dicoms and corresponding heatmaps

In [6]:
for dicom_filename in os.listdir(dicom_dir):
    if dicom_filename.endswith('.dcm'):
        dicom_path = os.path.join(dicom_dir, dicom_filename)
        dicom_image = pydicom.dcmread(dicom_path).pixel_array

        heatmap_filename = dicom_filename.replace('.dcm', '.png')
        heatmap_path = os.path.join(heatmap_dir, heatmap_filename)

        if os.path.exists(heatmap_path):
            heatmap_image = cv2.imread(heatmap_path, cv2.IMREAD_GRAYSCALE)
            fused_image = fuse_images_with_resizing(dicom_image, heatmap_image)
            dicom_id = dicom_filename.replace('.dcm', '')  
            fused_image_path = os.path.join(fused_dir, f"{dicom_id}.png")
            cv2.imwrite(fused_image_path, fused_image)
        else:
            print(f"Heatmap image not found for {dicom_filename}")

Updating the csv file

In [7]:
csv_file = '/kaggle/input/labels-large/manually_made_large.csv'  
updated_csv_file = '/kaggle/working/updated_csv_file.csv'  

df = pd.read_csv(csv_file)

fused_image_names = [os.path.splitext(f)[0] for f in os.listdir(fused_dir) if f.endswith('.png')]

df_filtered = df[df['fused_image_name'].isin(fused_image_names)]

df_filtered.to_csv(updated_csv_file, index=False)

print(f"Updated CSV file saved at: {updated_csv_file}")

Updated CSV file saved at: /kaggle/working/updated_csv_file.csv


In [8]:
csv_file_path = '/kaggle/working/updated_csv_file.csv'  
data_df = pd.read_csv(csv_file_path)

data_df.head()

Unnamed: 0,fused_image_name,label
0,002da0d9-ce49c30d-4dfcc1f8-746d2401-d8044d48,0
1,0066734a-35568fde-fd52ba23-ec66f3de-88d4aaf9,1
2,00fe73b4-5215bb4f-94bbccc4-ac5f4f6f-52805cfb,1
3,010fa20c-6ac04c8a-f6d4bc0b-eb1e735c-cd940793,2
4,018680d4-8fb864f0-dffebf54-bcba02ab-9b601e7a,0


Define image transformations

In [9]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.RandomHorizontalFlip(p=0.5),  
    transforms.RandomRotation(degrees=15),  
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),  
    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),  
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  
    transforms.GaussianBlur(kernel_size=3),  
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

Defining FusedImageDataset class

In [10]:
class FusedImageDataset(Dataset):
    def __init__(self, dataframe, image_folder, transform=None):
        self.dataframe = dataframe
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]['fused_image_name']
        
        if not img_name.endswith('.png'):
            img_name += '.png'
        
        label = self.dataframe.iloc[idx]['label']
        img_path = os.path.join(self.image_folder, img_name)
        
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file {img_path} does not exist.")
        
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

Creating dataset and dataloader

In [11]:
image_folder = '/kaggle/working/fused_images_kaggle'  

dataset = FusedImageDataset(data_df, image_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

len(dataset)

1032

**Multiple transformations and fusing**

In [12]:
# Updated Transform Function for DICOM (Grayscale)
transform_dicom = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.Grayscale(num_output_channels=3),  
    transforms.RandomHorizontalFlip(p=0.5),  
    transforms.RandomRotation(degrees=15),  
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
])

# Updated Transform Function for Heatmaps (Already RGB)
transform_heatmap = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

Function to convert DICOM to a PIL Image

In [13]:
def dicom_to_pil(dicom_path):
    dicom = pydicom.dcmread(dicom_path)
    
    image = dicom.pixel_array
    image = (image - np.min(image)) / (np.max(image) - np.min(image)) * 255.0
    image = image.astype(np.uint8)
    
    return Image.fromarray(image)

Function to save multiple augmented fused images

In [14]:
def save_multiple_augmented_fused_images(dicom_folder, heatmap_folder, save_folder, num_augmentations=5):
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    for dicom_file in os.listdir(dicom_folder):
        dicom_id = os.path.splitext(dicom_file)[0]  
        
        dicom_path = os.path.join(dicom_folder, dicom_file)
        heatmap_path = os.path.join(heatmap_folder, dicom_id + ".png")  

        dicom_image = dicom_to_pil(dicom_path)
        heatmap_image = Image.open(heatmap_path).convert("RGB")

        for i in range(num_augmentations):
            dicom_image_transformed = transform_dicom(dicom_image)  
            heatmap_image_transformed = transform_heatmap(heatmap_image)

            fused_image = dicom_image_transformed * heatmap_image_transformed
            fused_image = fused_image.permute(1, 2, 0).numpy()  
            fused_image = (fused_image * 255).astype(np.uint8)  

            fused_image_pil = Image.fromarray(fused_image)

            save_path = os.path.join(save_folder, f"{dicom_id}_{i + 1}.png")  

            fused_image_pil.save(save_path)

In [15]:
dicom_folder = '/kaggle/input/dicom-files/dicom files'  
heatmap_folder = '/kaggle/input/heatmap-images/heatmap_images'  
save_folder = '/kaggle/working/multiple_aug_fused'  

save_multiple_augmented_fused_images(dicom_folder, heatmap_folder, save_folder, num_augmentations=15)

**Updating the csv file**

In [16]:
augmented_dir = '/kaggle/working/multiple_aug_fused'  
csv_file = '/kaggle/input/labels-large/manually_made_large.csv'  
aug_csv_file = '/kaggle/working/aug_csv_file.csv'  
df = pd.read_csv(csv_file)

augmented_image_names = [os.path.splitext(f)[0] for f in os.listdir(augmented_dir) if f.endswith('.png')]
augmented_rows = []

for aug_img_name in augmented_image_names:
    original_dicom_id = "_".join(aug_img_name.split("_")[:-1])  

    if original_dicom_id in df['fused_image_name'].values:
        original_rows = df[df['fused_image_name'] == original_dicom_id]

        for index, original_row in original_rows.iterrows():
            new_row = {
                'fused_image_name': aug_img_name,
                'label': original_row['label']  
            }
            augmented_rows.append(new_row)

df_augmented = pd.DataFrame(augmented_rows)
df_updated = pd.concat([df, df_augmented], ignore_index=True)
df_updated.to_csv(aug_csv_file, index=False)

print(f"Updated CSV file with augmented images saved at: {aug_csv_file}")

Updated CSV file with augmented images saved at: /kaggle/working/aug_csv_file.csv


**adding both original and transformed to new folder**

In [17]:
original_fused_dir = '/kaggle/working/fused_images_kaggle'  
augmented_fused_dir = '/kaggle/working/multiple_aug_fused'  

original_images = [f for f in os.listdir(original_fused_dir) if os.path.isfile(os.path.join(original_fused_dir, f))]

counter = 0

for img in original_images:
    img_id = os.path.splitext(img)[0]
    new_img_name = f"{img_id}_0.png" 
    
    src_path = os.path.join(original_fused_dir, img)
    dest_path = os.path.join(augmented_fused_dir, new_img_name)
    shutil.copy(src_path, dest_path)
    
    counter += 1

print(f"Successfully added {counter} original fused images (without transformations) to the directory.")

Successfully added 1032 original fused images (without transformations) to the directory.


In [18]:
untransformed_ids = set(f'dicom_id_{i}' for i in range(1, 6)) 

df_filtered = df_updated[~df_updated['fused_image_name'].apply(lambda x: os.path.basename(x) in untransformed_ids)]

df_filtered.to_csv('updated_original_dataset.csv', index=False)

In [19]:
image_dir = '/kaggle/working/multiple_aug_fused'
csv_file_path = '/kaggle/working/aug_csv_file.csv'
updated_csv_path = 'updated_original_dataset_2.csv'

df = pd.read_csv(csv_file_path)

for filename in os.listdir(image_dir):
    if filename.startswith('dicom_id') and filename.count('_') == 0:  
        file_path = os.path.join(image_dir, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)
            print(f"Deleted: {filename}")

remaining_images = set(os.listdir(image_dir))

df_filtered = df[~df['fused_image_name'].isin(remaining_images)]

df_filtered.to_csv(updated_csv_path, index=False)

print(f"Removed entries from the CSV. Updated CSV saved as '{updated_csv_path}'.")

Removed entries from the CSV. Updated CSV saved as 'updated_original_dataset_2.csv'.


In [20]:
import os
import pandas as pd

csv_file = '/kaggle/working/updated_original_dataset_2.csv'  
image_folder = '/kaggle/working/multiple_aug_fused'  

df = pd.read_csv(csv_file)

csv_row_count = len(df)
print(f"Number of entries in the CSV file: {csv_row_count}")

image_count = len([f for f in os.listdir(image_folder) if os.path.isfile(os.path.join(image_folder, f))])
print(f"Number of images in the image folder: {image_count}")

Number of entries in the CSV file: 16512
Number of images in the image folder: 16512


In [21]:
model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=3)

weights_path = "/kaggle/input/vit_model_w_more_aug/tensorflow1/default/1/final_model_20250331_135231.pth" 
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

model.eval()

print("Model loaded")

  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))


Model loaded


**CBMs**

In [22]:
concepts_csv = '/kaggle/input/final-concepts-latest/FINAL_CONCEPT(latest).csv'
image_dir = '/kaggle/working/multiple_aug_fused'
checkpoint_path = '/kaggle/input/vit_model_w_more_aug/tensorflow1/default/1/final_model_20250331_135231.pth'
log_file = 'lil_overfitting_log_file.txt'  # Log file name

In [23]:
with open(log_file, 'w') as f:
    f.write(f"Training Log - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write("="*50 + "\n\n")

In [24]:
concepts_df = pd.read_csv(concepts_csv)

if 'sid' in concepts_df.columns:
    concepts_df = concepts_df.drop(columns=['sid'])

label_columns = ['normal', 'pneumonia', 'chf']
def combine_labels(row):
    for idx, col in enumerate(label_columns):
        if row[col] == 1:
            return idx
    return -1

concepts_df['label'] = concepts_df.apply(combine_labels, axis=1)
concepts_df = concepts_df.drop(columns=label_columns)

In [25]:
def is_text_column(series, col):
    if col in ['id', 'label']:
        return False
    return series.apply(lambda x: isinstance(x, str)).any()

In [26]:
text_columns = [col for col in concepts_df.columns if is_text_column(concepts_df[col], col)]
concepts_df = concepts_df.drop(columns=text_columns)

In [27]:
all_images = os.listdir(image_dir)
def get_base_id(filename):
    return filename.rsplit('_', 1)[0]

In [28]:
available_base_ids = set(get_base_id(f) for f in all_images)
initial_concept_count = len(concepts_df)
concepts_df = concepts_df[concepts_df['id'].isin(available_base_ids)]
filtered_concept_count = len(concepts_df)

print(f"Initial number of entries in concepts file: {initial_concept_count}")
print(f"Number of entries after filtering missing images: {filtered_concept_count}")
print(f"Number of images in the directory: {len(all_images)}")

Initial number of entries in concepts file: 1072
Number of entries after filtering missing images: 1027
Number of images in the directory: 16512


In [29]:
augmented_data = []
for _, row in concepts_df.iterrows():
    base_id = row['id']
    for suffix in range(1, 11):
        augmented_id = f"{base_id}_{suffix}"
        if any(img.startswith(augmented_id) for img in all_images) and augmented_id not in concepts_df['id'].values:
            augmented_row = row.copy()
            augmented_row['id'] = augmented_id
            augmented_data.append(augmented_row)

if augmented_data:
    augmented_df = pd.DataFrame(augmented_data)
    concepts_df = pd.concat([concepts_df, augmented_df], ignore_index=True)

concept_columns = [col for col in concepts_df.columns if col not in ['id', 'label']]
num_concepts = len(concept_columns)

In [30]:
class StrongAugment:
    def __call__(self, img):
        transform = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.GaussianBlur(3),
            transforms.ToTensor()
        ])
        return transform(img)

In [31]:
class ImageConceptDataset(Dataset):
    def __init__(self, dataframe, image_dir):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.standard_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.strong_transform = StrongAugment()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        base_id = row['id']
        label = row['label']

        found_image = None
        for suffix in range(10):
            image_path = os.path.join(self.image_dir, f"{base_id}_{suffix}.png")
            if os.path.exists(image_path):
                found_image = image_path
                break

        if found_image is None:
            return None

        image = Image.open(found_image).convert('RGB')
        image = self.strong_transform(image) if label == 1 else self.standard_transform(image)
        concepts = row[concept_columns].values.astype(np.float32)

        return image, torch.tensor(concepts), torch.tensor(label, dtype=torch.long)


In [32]:
train_df, val_df = train_test_split(concepts_df, test_size=0.5, stratify=concepts_df['label'], random_state=42)
train_dataset = [item for item in ImageConceptDataset(train_df, image_dir) if item is not None]
val_dataset = [item for item in ImageConceptDataset(val_df, image_dir) if item is not None]

labels = np.array([label.item() for _, _, label in train_dataset])
class_counts = np.bincount(labels)
class_weights = 1.0 / np.sqrt(class_counts)
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(labels), replacement=True)
train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [33]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = len(label_columns)

class ViTWithConcepts(nn.Module):
    def __init__(self, num_concepts, num_classes):
        super(ViTWithConcepts, self).__init__()
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        hidden_dim = self.vit.head.in_features
        self.vit.head = nn.Identity()
        self.concept_head = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_concepts)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(num_concepts, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        features = self.vit(x)
        concepts = self.concept_head(features)
        logits = self.prediction_head(concepts)
        return concepts, logits

model = ViTWithConcepts(num_concepts, num_classes).to(device)
concept_criterion = nn.CrossEntropyLoss()
classification_criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-4) 

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [34]:
def train_model(model, train_loader, val_loader, num_epochs=30):
    with open(log_file, 'a') as f:
        f.write(f"Training started with {num_epochs} epochs\n")
        f.write(f"Model architecture: {str(model)}\n")
        f.write(f"Optimizer: {str(optimizer)}\n\n")
        
        for epoch in range(num_epochs):
            # Training phase
            model.train()
            running_loss = 0.0
            correct, total = 0, 0
            all_preds = []
            all_labels = []
            
            # Initialize counters for class-wise accuracy
            class_correct = [0] * num_classes
            class_total = [0] * num_classes
            
            for images, true_concepts, labels in train_loader:
                images, true_concepts, labels = images.to(device), true_concepts.to(device), labels.to(device)
                optimizer.zero_grad()
                predicted_concepts, class_logits = model(images)
                loss = 0.8 * concept_criterion(predicted_concepts, true_concepts) + 0.2 * classification_criterion(class_logits, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                
                _, preds = torch.max(class_logits, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                
                # Calculate class-wise correct predictions
                for i in range(num_classes):
                    class_mask = (labels == i)
                    class_correct[i] += (preds[class_mask] == labels[class_mask]).sum().item()
                    class_total[i] += class_mask.sum().item()
            
            # Calculate training metrics
            train_acc = correct / total
            train_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
            train_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
            train_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
            
            # Calculate class-wise training accuracy
            train_class_acc = []
            for i in range(num_classes):
                if class_total[i] > 0:
                    train_class_acc.append(class_correct[i] / class_total[i])
                else:
                    train_class_acc.append(0.0)
            
            # Validation phase
            model.eval()
            val_correct, val_total = 0, 0
            val_preds = []
            val_labels = []
            
            # Initialize counters for class-wise validation accuracy
            val_class_correct = [0] * num_classes
            val_class_total = [0] * num_classes
            
            with torch.no_grad():
                for images, _, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    _, class_logits = model(images)
                    _, preds = torch.max(class_logits, 1)
                    val_correct += (preds == labels).sum().item()
                    val_total += labels.size(0)
                    val_preds.extend(preds.cpu().numpy())
                    val_labels.extend(labels.cpu().numpy())
                    
                    # Calculate class-wise correct predictions for validation
                    for i in range(num_classes):
                        class_mask = (labels == i)
                        val_class_correct[i] += (preds[class_mask] == labels[class_mask]).sum().item()
                        val_class_total[i] += class_mask.sum().item()
            
            # Calculate validation metrics
            val_acc = val_correct / val_total
            val_precision = precision_score(val_labels, val_preds, average='weighted', zero_division=0)
            val_recall = recall_score(val_labels, val_preds, average='weighted', zero_division=0)
            val_f1 = f1_score(val_labels, val_preds, average='weighted', zero_division=0)
            
            # Calculate class-wise validation accuracy
            val_class_acc = []
            for i in range(num_classes):
                if val_class_total[i] > 0:
                    val_class_acc.append(val_class_correct[i] / val_class_total[i])
                else:
                    val_class_acc.append(0.0)
            
            f.write(f"\nEpoch {epoch + 1}/{num_epochs}\n")
            f.write("-"*30 + "\n")
            f.write(f"Training Loss: {running_loss:.4f}\n")
            f.write(f"Training Accuracy: {train_acc:.4f}\n")
            f.write(f"Training Precision: {train_precision:.4f}\n")
            f.write(f"Training Recall: {train_recall:.4f}\n")
            f.write(f"Training F1 Score: {train_f1:.4f}\n")
            
            f.write("\nTraining Class-wise Accuracy:\n")
            for i, class_name in enumerate(label_columns):
                f.write(f"{class_name}: {train_class_acc[i]:.4f} ({class_correct[i]}/{class_total[i]})\n")
            
            f.write("\nValidation Metrics:\n")
            f.write(f"Validation Accuracy: {val_acc:.4f}\n")
            f.write(f"Validation Precision: {val_precision:.4f}\n")
            f.write(f"Validation Recall: {val_recall:.4f}\n")
            f.write(f"Validation F1 Score: {val_f1:.4f}\n")
            
            f.write("\nValidation Class-wise Accuracy:\n")
            for i, class_name in enumerate(label_columns):
                f.write(f"{class_name}: {val_class_acc[i]:.4f} ({val_class_correct[i]}/{val_class_total[i]})\n")
            
            f.write("-"*50 + "\n")
            
            # Also print to console
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print(f"Training Loss: {running_loss:.4f}")
            print(f"Training Accuracy: {train_acc:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | F1: {train_f1:.4f}")
            
            print("\nTraining Class-wise Accuracy:")
            for i, class_name in enumerate(label_columns):
                print(f"{class_name}: {train_class_acc[i]:.4f} ({class_correct[i]}/{class_total[i]})")
            
            print("\nValidation Metrics:")
            print(f"Validation Accuracy: {val_acc:.4f} | Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | F1: {val_f1:.4f}")
            
            print("\nValidation Class-wise Accuracy:")
            for i, class_name in enumerate(label_columns):
                print(f"{class_name}: {val_class_acc[i]:.4f} ({val_class_correct[i]}/{val_class_total[i]})")
            
    # Final summary
    with open(log_file, 'a') as f:
        f.write("\nTraining completed!\n")

        print(classification_report(val_labels, val_preds))

        cm = confusion_matrix(val_labels, val_preds)
        plt.figure(figsize=(10, 7))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                    xticklabels=label_columns, yticklabels=label_columns)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Confusion Matrix - Epoch {epoch+1}')
        cm_filename = f'confusion_matrix_epoch_{epoch+1}.png'
        plt.savefig(cm_filename)
        plt.close()
        
        f.write(f"\nLog file and confusion matrices saved in current directory.\n")

In [35]:
train_model(model, train_loader, val_loader)


Epoch 1/30
Training Loss: 1364.5747
Training Accuracy: 0.3740 | Precision: 0.3481 | Recall: 0.3740 | F1: 0.2701

Training Class-wise Accuracy:
normal: 0.0964 (16/166)
pneumonia: 0.0390 (6/154)
chf: 0.8724 (171/196)

Validation Metrics:
Validation Accuracy: 0.3131 | Precision: 0.2158 | Recall: 0.3131 | F1: 0.2455

Validation Class-wise Accuracy:
normal: 0.0000 (0/170)
pneumonia: 0.3171 (52/164)
chf: 0.6102 (108/177)

Epoch 2/30
Training Loss: 1162.1786
Training Accuracy: 0.3081 | Precision: 0.2869 | Recall: 0.3081 | F1: 0.2702

Training Class-wise Accuracy:
normal: 0.0497 (9/181)
pneumonia: 0.4583 (77/168)
chf: 0.4371 (73/167)

Validation Metrics:
Validation Accuracy: 0.3699 | Precision: 0.3374 | Recall: 0.3699 | F1: 0.3060

Validation Class-wise Accuracy:
normal: 0.2588 (44/170)
pneumonia: 0.8110 (133/164)
chf: 0.0678 (12/177)

Epoch 3/30
Training Loss: 1077.0252
Training Accuracy: 0.3992 | Precision: 0.3793 | Recall: 0.3992 | F1: 0.3788

Training Class-wise Accuracy:
normal: 0.5055 (