In [9]:
import os
# Torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight
# Torch Vision
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
# Transformers
from transformers import ViTImageProcessor, ViTForImageClassification
# Utils
from tqdm import tqdm
import numpy as np
# Train Test Evaluate
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
# Visualization
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

In [10]:
g_vit_patch = 'google/vit-base-patch16-224'
base_path = '/kaggle/input/ecg-image-data' # For training Data
base_output_path = '/kaggle/working' # For Storing weights and all
# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 1
NUM_WORKERS = 2

In [11]:
class CNNModel(nn.Module):
    def __init__(self, num_classes, device, cnn_model=models.resnet18(weights='ResNet18_Weights.DEFAULT')):
        super(CNNModel, self).__init__()
        self.cnn = cnn_model
        num_ftrs = self.cnn.fc.in_features
        self.cnn.fc = nn.Linear(num_ftrs, num_classes)
        self.device = device

    def forward(self, x):
        # Ensure input is on the correct device
        x = x.to(self.device)
        return self.cnn(x)

# ViT Model with Explicit Device Handling
class ViTModel(nn.Module):
    def __init__(self, num_classes, device, vit_model_name=g_vit_patch):
        super(ViTModel, self).__init__()
        self.feature_extractor = ViTImageProcessor.from_pretrained(vit_model_name,do_rescale=False)
        self.model = ViTForImageClassification.from_pretrained(
            vit_model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        self.device = device

    def forward(self, x):
        # Ensure input is on the correct device
        x = x.to(self.device)  # Move to the specified device
        # Prepare inputs for ViT model
        feature_input = self.feature_extractor(x, return_tensors="pt").to(self.device)
        logits = self.model(**feature_input).logits
        return logits

# Ensemble Model with Explicit Device Handling
class EnsembleModel(nn.Module):
    def __init__(self, num_classes, cnn_model, vit_model, device):
        super(EnsembleModel, self).__init__()
        self.cnn_model = cnn_model
        self.vit_model = vit_model
        self.fc = nn.Linear(num_classes * 2, num_classes)  # Final linear layer
        self.device = device

    def forward(self, x):
        # Ensure input is on the correct device
        x = x.to(self.device)

        # Forward through CNN and ViT
        cnn_output = self.cnn_model(x)  # CNN output
        vit_output = self.vit_model(x)  # ViT output

        # Combine outputs
        combined_output = torch.cat((cnn_output, vit_output), dim=1)  # Concatenate

        # Final linear output
        return self.fc(combined_output)

    def eval(self, *args, **kwargs):
        super().eval(*args, **kwargs)  # Call eval() on the parent class

        # Call eval() on the underlying models
        self.cnn_model.eval(*args, **kwargs)
        self.vit_model.eval(*args, **kwargs)

    def train(self, *args, **kwargs):
        super().train( *args, **kwargs)  # Call train() on the parent class

        # Call train() on the underlying models
        self.cnn_model.train( *args, **kwargs)
        self.vit_model.train( *args, **kwargs)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 170MB/s]


In [12]:
def preprocess(base_path: str, device: torch.device):
    # Your transformations
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    print("Loading Datasets...")
    # Load datasets
    train_dataset = datasets.ImageFolder(
        root=os.path.join(base_path, "train"),
        transform=train_transform
    )

    test_dataset = datasets.ImageFolder(
        root=os.path.join(base_path, "test"),
        transform=test_transform
    )

    print("Balancing Training Datasets...")

    # Assuming you have a dataset 'dataset' with labels 'labels'
    # Calculate class counts
    labels = torch.tensor(train_dataset.targets)
    train_class_counts = torch.bincount(labels)
    train_desired_count = torch.min(train_class_counts)

    # Calculate weights for each sample based on class counts and desired count
    weights = train_desired_count / train_class_counts.float()  # Adjusted weights based on desired count

    sample_weights = weights[labels]

    # Create WeightedRandomSampler with the calculated weights
    num_samples = int(train_desired_count * len(train_class_counts))
    train_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)

    print("Preparing the loaders...")
    # Load data using DataLoader with the specified sampler
    train_loader = DataLoader(train_dataset, num_workers=4, batch_size=75, sampler=train_sampler)
    test_loader = DataLoader(test_dataset, num_workers=4, batch_size=75, shuffle=True)

    print(f"{train_class_counts=} {train_desired_count=}")
    print("All Done...")

    return train_dataset, test_dataset, train_loader, test_loader


In [13]:
data_folder = os.path.join(base_path, 'ECG_Image_data')

train_dataset, test_dataset, train_loader, test_loader = preprocess(data_folder, device)


Loading Datasets...


FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/ecg-image-data/ECG_Image_data/train'