In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torchvision.models import ResNet34_Weights
from torch.utils.data import DataLoader, Dataset, Subset
from PIL import Image
from collections import defaultdict, Counter
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from sklearn.model_selection import KFold

In [None]:
# Machine-specific settings are read from config.json in the current working
# directory so the dataset location is not hard-coded in the notebook.
# JSON is UTF-8 by specification; pass the encoding explicitly so non-ASCII paths
# are not mis-decoded by the platform's locale codec (e.g. cp1252 on Windows).
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)
# Root folder expected to contain the train/, val/ and test/ splits used below.
DATASET_PATH = config["DATASET_PATH"]

In [None]:
class ChestXrayDataset3Clases(Dataset):
    """Chest X-ray image dataset with three classes: NORMAL, BACTERIA and VIRUS.

    Expected layout under ``root_dir`` (class folders that do not exist are
    silently skipped)::

        root_dir/NORMAL/...               -> label 0
        root_dir/PNEUMONIA/BACTERIA/...   -> label 1
        root_dir/PNEUMONIA/VIRUS/...      -> label 2

    Each class folder is walked recursively. Files ending in .jpeg/.jpg/.png
    (case-insensitive) are collected as ``(path, label)`` pairs in
    ``self.samples``, in a deterministic (sorted) order.

    Args:
        root_dir: directory of one split (e.g. ``<DATASET_PATH>/train``).
        transform: optional callable applied to the RGB PIL image in
            ``__getitem__``; when None the PIL image is returned as-is.
    """

    # Class folder name -> integer label (iteration order defines sample order).
    LABELS_MAP = {"NORMAL": 0, "BACTERIA": 1, "VIRUS": 2}
    # Accepted file suffixes, compared against the lower-cased file name.
    IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        for class_name, label in self.LABELS_MAP.items():
            if class_name == "NORMAL":
                folder_path = os.path.join(root_dir, "NORMAL")
            else:
                folder_path = os.path.join(root_dir, "PNEUMONIA", class_name)
            if not os.path.isdir(folder_path):
                continue
            # os.walk/listdir order is filesystem-dependent; sorting directories
            # (in place, so traversal follows it) and files makes sample indices
            # identical on every machine, which index-based splits rely on.
            for dirpath, dirnames, filenames in os.walk(folder_path):
                dirnames.sort()
                for filename in sorted(filenames):
                    if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                        self.samples.append((os.path.join(dirpath, filename), label))

    def __len__(self):
        """Number of collected image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image (transformed if set) and its int label."""
        img_path, label = self.samples[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label


In [None]:
# Inputs are resized to 224x224 and normalised with ImageNet channel statistics,
# matching the ImageNet-pretrained ResNet-34 weights used by the model.
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation, applied to the PIL image before tensor conversion.
# NOTE(review): RandomRotation(15) and RandomAffine(degrees=10) both rotate, so the
# effective rotation range is stacked; and a horizontal flip mirrors chest anatomy —
# confirm both are intended.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Deterministic preprocessing for validation/test: fixed resize, no augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
class TransformedSubset(Dataset):
    """Wrap an indexable dataset of ``(image, label)`` pairs and transform lazily.

    The wrapped dataset is left untouched; ``transform`` (if truthy) is applied to
    the image on every ``__getitem__`` call, while the label passes through.
    Useful when the underlying data was built with ``transform=None`` and
    different views of it need different preprocessing.

    Args:
        subset: any object supporting ``len()`` and integer indexing that yields
            ``(image, label)`` pairs (e.g. a ``torch.utils.data.Subset``).
        transform: optional callable applied to the image.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, target = self.subset[idx]
        return (self.transform(sample) if self.transform else sample), target


In [None]:
# The official train and val splits are merged into a single untransformed pool
# (transform=None); presumably cross-validation folds are drawn from it later and
# wrapped with their own train/val transforms — confirm in the training loop.
train_data = ChestXrayDataset3Clases(os.path.join(DATASET_PATH, "train"), transform=None)
val_data = ChestXrayDataset3Clases(os.path.join(DATASET_PATH, "val"), transform=None)
full_data = torch.utils.data.ConcatDataset([train_data, val_data])
# ChestXrayDataset3Clases silently skips missing class folders, so a wrong
# DATASET_PATH would otherwise only surface later as an obscure empty-data error.
assert len(full_data) > 0, f"No train/val images found under {DATASET_PATH!r}"

# Held-out test split, evaluated in a fixed order with the deterministic transform.
test_data = ChestXrayDataset3Clases(os.path.join(DATASET_PATH, "test"), transform=val_transform)
assert len(test_data) > 0, f"No test images found under {DATASET_PATH!r}"
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [None]:
# Pick the compute backend in order of preference: CUDA, then DirectML (via the
# optional torch_directml package), then CPU. The chosen backend's short label is
# printed so the notebook output records where the code ran.
if torch.cuda.is_available():
    device, _backend = torch.device("cuda"), "cuda"
else:
    try:
        import torch_directml
        device, _backend = torch_directml.device(), "amd"
    except ImportError:
        # DirectML unavailable: fall back to the CPU.
        device, _backend = torch.device("cpu"), "cpu"
print(_backend)


In [None]:
class ResNet34FineTune(nn.Module):
    """ResNet-34 feature extractor followed by a new MLP classifier head.

    The torchvision ResNet-34's own ``fc`` layer is replaced by ``nn.Identity`` so
    the backbone outputs its pooled feature vector (width read from the original
    ``fc.in_features``). That vector feeds a new head:
    Linear(in_features, 256) -> BatchNorm1d -> ReLU -> Dropout(0.45)
    -> Linear(256, 128) -> BatchNorm1d -> ReLU -> Dropout(0.35)
    -> Linear(128, num_classes). The output is raw scores (no softmax).

    Args:
        num_classes: number of output logits (default 3).
        pretrained: if True (default), load ``ResNet34_Weights.DEFAULT``; if False,
            the backbone is randomly initialised and nothing is downloaded.
        freeze_backbone: if True (default), set ``requires_grad=False`` on every
            backbone parameter so only the new head receives gradients.

    NOTE(review): freezing only disables gradients. The backbone's BatchNorm layers
    still update their running statistics whenever the module is in ``train()``
    mode; call ``model.resnet.eval()`` after ``model.train()`` if fully frozen
    statistics are wanted.
    NOTE(review): the head's BatchNorm1d raises in train mode on a batch of size 1,
    so training loaders may need ``drop_last=True`` — confirm in the training loop.
    """

    def __init__(self, num_classes=3, pretrained=True, freeze_backbone=True):
        super().__init__()
        weights = ResNet34_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet34(weights=weights)
        if freeze_backbone:
            for param in self.resnet.parameters():
                param.requires_grad = False
        in_features = self.resnet.fc.in_features
        # Drop the original classifier; the backbone now returns pooled features.
        self.resnet.fc = nn.Identity()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.45),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.35),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is an image batch, presumably (batch, 3, 224, 224) given the
        notebook's transforms — TODO confirm for other inputs.
        """
        features = self.resnet(x)
        return self.fc(features)


In [None]:
# One seed drives both the fold assignment and torch's global RNG. torch.manual_seed
# covers torch-based randomness on all devices: head initialisation, dropout,
# torchvision augmentations and DataLoader shuffling. GPU kernels may still be
# nondeterministic.
SEED = 42
N_SPLITS = 5
torch.manual_seed(SEED)

# NOTE(review): plain KFold ignores labels; consider StratifiedKFold if the class
# counts are imbalanced — confirm against the data.
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
results = {}

# Learning-curve accumulators, presumably one entry per fold appended by the
# training loop (not shown here).
all_train_losses, all_val_losses = [], []
all_train_accs, all_val_accs = [], []