In [1]:
import os
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm


import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, vgg16
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split



import warnings
warnings.filterwarnings("ignore")

np.random.seed(1234)




In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [3]:
data = pd.read_csv('y_train.csv')

In [4]:
data.head()

Unnamed: 0,file_id,cell_line
0,1,MCF7
1,2,RT4
2,3,U-2 OS
3,4,RT4
4,5,A549


In [5]:
train_data, val_data = train_test_split(
    data, test_size=0.2, random_state=42, stratify=data['cell_line'])


In [6]:
train_data.to_csv("train_data.csv", index=False)
val_data.to_csv("val_data.csv", index=False)


In [7]:
class CellLineDataset(Dataset):
    def __init__(self, img_dir, labels_file=None, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        if labels_file:
            self.labels_df = pd.read_csv(labels_file)
            self.has_labels = True
            # Create a dictionary mapping class names to integers
            self.class_to_idx = {class_name: i for i, class_name in enumerate(
                self.labels_df["cell_line"].unique())}
        else:
            self.has_labels = False

    def __len__(self):
        return len(self.labels_df)


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.has_labels:
            row = self.labels_df.iloc[idx]
            sample_id = row['file_id']
            img_paths = [os.path.join(self.img_dir, f"{str(sample_id).zfill(5)}_{color}.png") for color in ["blue", "red", "yellow"]]
            imgs = [Image.open(img_path) for img_path in img_paths]
            img = torch.stack([torchvision.transforms.functional.to_tensor(im) for im in imgs]).squeeze(1)

            if self.transform:
                img = self.transform(img)

            # Convert label to integer
            label = self.class_to_idx[row['cell_line']]
            return img, label
        else:
            raise IndexError(f"No matching row found for index {idx}")


In [8]:
def calculate_mean_std(loader):
    mean = 0.
    std = 0.
    nb_samples = 0.
    for data, _ in tqdm(loader):
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples
    return mean, std

In [9]:
epochs = 100
batch_size = 32
lr = 0.01
step_size = 5
gamma = 0.1

In [10]:
raw_train_data = CellLineDataset(
    img_dir="images_train/images_train/", labels_file="y_train.csv")
raw_train_loader = DataLoader(raw_train_data, batch_size=batch_size, shuffle=True)

mean, std = calculate_mean_std(raw_train_loader)




100%|██████████| 301/301 [02:56<00:00,  1.70it/s]


In [11]:
transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=mean, std=std)
])


In [12]:
train_dataset = CellLineDataset(
    img_dir="images_train/images_train/", labels_file="train_data.csv", transform=transform)
val_dataset = CellLineDataset(
    img_dir="images_train/images_train/", labels_file='val_data.csv', transform=transform)


In [13]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


In [14]:

model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 9)
model = model.to(device)


In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)


In [17]:
for epoch in range(1, epochs +1):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader)
    for inputs, labels in pbar:
        # Move inputs and labels to device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += loss.item() * inputs.size(0)

    # Calculate average losses
    epoch_loss = running_loss / len(train_loader.dataset)
    print('Training Loss: {:.4f}'.format(epoch_loss))
    scheduler.step()
    # Validation on the test set
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        running_loss += loss.item() * inputs.size(0)

    val_loss = running_loss / len(test_loader.dataset)
    print('Validation Loss: {:.4f}'.format(val_loss))


100%|██████████| 241/241 [00:56<00:00,  4.29it/s]


Training Loss: 2.1185
Validation Loss: 2.6670


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.8155
Validation Loss: 6.3097


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.7454
Validation Loss: 2.7935


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.6596
Validation Loss: 3.7883


100%|██████████| 241/241 [00:43<00:00,  5.52it/s]


Training Loss: 1.6351
Validation Loss: 3.3095


100%|██████████| 241/241 [00:43<00:00,  5.59it/s]


Training Loss: 1.4871
Validation Loss: 3.0460


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.4055
Validation Loss: 3.4955


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.3498
Validation Loss: 3.4993


100%|██████████| 241/241 [00:45<00:00,  5.36it/s]


Training Loss: 1.2847
Validation Loss: 3.5510


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.2268
Validation Loss: 3.5263


100%|██████████| 241/241 [00:45<00:00,  5.36it/s]


Training Loss: 1.1384
Validation Loss: 3.6668


100%|██████████| 241/241 [00:45<00:00,  5.33it/s]


Training Loss: 1.0931
Validation Loss: 3.7642


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0677
Validation Loss: 3.8428


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0559
Validation Loss: 3.7841


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0535
Validation Loss: 3.9823


100%|██████████| 241/241 [00:43<00:00,  5.48it/s]


Training Loss: 1.0181
Validation Loss: 3.8748


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.0179
Validation Loss: 3.8920


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0254
Validation Loss: 3.9878


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0193
Validation Loss: 4.0181


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 0.9962
Validation Loss: 4.0431


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 1.0022
Validation Loss: 4.0000


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 1.0032
Validation Loss: 4.0466


100%|██████████| 241/241 [00:43<00:00,  5.56it/s]


Training Loss: 1.0174
Validation Loss: 3.9548


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 1.0146
Validation Loss: 3.9168


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 1.0049
Validation Loss: 3.9552


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 0.9926
Validation Loss: 3.9254


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 0.9999
Validation Loss: 4.0207


100%|██████████| 241/241 [00:44<00:00,  5.42it/s]


Training Loss: 0.9988
Validation Loss: 4.0279


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0152
Validation Loss: 3.9344


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0114
Validation Loss: 4.0482


100%|██████████| 241/241 [00:43<00:00,  5.56it/s]


Training Loss: 0.9948
Validation Loss: 4.1028


100%|██████████| 241/241 [00:43<00:00,  5.57it/s]


Training Loss: 1.0073
Validation Loss: 4.0754


100%|██████████| 241/241 [00:43<00:00,  5.58it/s]


Training Loss: 1.0054
Validation Loss: 3.9413


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0098
Validation Loss: 4.0399


100%|██████████| 241/241 [00:45<00:00,  5.33it/s]


Training Loss: 1.0038
Validation Loss: 4.0111


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.0000
Validation Loss: 3.9055


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 0.9904
Validation Loss: 3.9922


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0064
Validation Loss: 3.9868


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0045
Validation Loss: 3.9235


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0158
Validation Loss: 3.9416


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0060
Validation Loss: 4.0556


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0139
Validation Loss: 3.9626


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0033
Validation Loss: 3.9285


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.0124
Validation Loss: 3.8854


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0056
Validation Loss: 4.0410


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.0112
Validation Loss: 4.0516


100%|██████████| 241/241 [00:45<00:00,  5.32it/s]


Training Loss: 0.9983
Validation Loss: 3.9727


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.0055
Validation Loss: 3.9761


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.0102
Validation Loss: 4.0459


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 0.9992
Validation Loss: 4.0038


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.0285
Validation Loss: 3.9996


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0044
Validation Loss: 3.9667


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 0.9886
Validation Loss: 3.9580


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0065
Validation Loss: 3.9762


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0063
Validation Loss: 4.0328


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.0041
Validation Loss: 4.0098


100%|██████████| 241/241 [00:45<00:00,  5.33it/s]


Training Loss: 1.0062
Validation Loss: 3.9939


100%|██████████| 241/241 [00:44<00:00,  5.42it/s]


Training Loss: 1.0026
Validation Loss: 3.9032


100%|██████████| 241/241 [00:44<00:00,  5.44it/s]


Training Loss: 1.0145
Validation Loss: 4.0517


100%|██████████| 241/241 [00:44<00:00,  5.47it/s]


Training Loss: 1.0001
Validation Loss: 3.8331


100%|██████████| 241/241 [00:43<00:00,  5.53it/s]


Training Loss: 1.0148
Validation Loss: 3.9132


100%|██████████| 241/241 [00:43<00:00,  5.52it/s]


Training Loss: 1.0018
Validation Loss: 3.9858


100%|██████████| 241/241 [00:43<00:00,  5.57it/s]


Training Loss: 1.0003
Validation Loss: 4.0188


100%|██████████| 241/241 [00:43<00:00,  5.57it/s]


Training Loss: 1.0029
Validation Loss: 3.9380


100%|██████████| 241/241 [00:44<00:00,  5.40it/s]


Training Loss: 1.0020
Validation Loss: 3.9818


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0211
Validation Loss: 4.1049


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0116
Validation Loss: 3.9409


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0091
Validation Loss: 4.0467


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0052
Validation Loss: 3.9683


100%|██████████| 241/241 [00:45<00:00,  5.33it/s]


Training Loss: 1.0043
Validation Loss: 3.8822


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0021
Validation Loss: 4.0142


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 0.9924
Validation Loss: 3.9166


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0121
Validation Loss: 3.9341


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0000
Validation Loss: 3.9581


100%|██████████| 241/241 [00:45<00:00,  5.34it/s]


Training Loss: 1.0085
Validation Loss: 4.0151


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0166
Validation Loss: 3.9434


100%|██████████| 241/241 [00:45<00:00,  5.36it/s]


Training Loss: 1.0076
Validation Loss: 3.9633


100%|██████████| 241/241 [00:45<00:00,  5.35it/s]


Training Loss: 1.0040
Validation Loss: 3.8746


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0121
Validation Loss: 3.9694


100%|██████████| 241/241 [00:44<00:00,  5.37it/s]


Training Loss: 1.0005
Validation Loss: 3.9169


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 0.9942
Validation Loss: 3.9611


100%|██████████| 241/241 [00:45<00:00,  5.32it/s]


Training Loss: 1.0057
Validation Loss: 4.0383


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 0.9949
Validation Loss: 3.9226


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 0.9867
Validation Loss: 4.0049


100%|██████████| 241/241 [00:45<00:00,  5.33it/s]


Training Loss: 1.0021
Validation Loss: 4.0340


100%|██████████| 241/241 [00:44<00:00,  5.36it/s]


Training Loss: 1.0067
Validation Loss: 4.1131


 28%|██▊       | 68/241 [00:13<00:33,  5.19it/s]


KeyboardInterrupt: 

In [18]:
torch.save(model.state_dict(), 'resnet50_4.pth')

In [19]:
model = resnet50(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 9)  # output layer for 9 classes
model.load_state_dict(torch.load('resnet50_4.pth'))
model = model.to(device)

In [20]:
class CellLineDataset(Dataset):
    def __init__(self, img_dir, labels_file=None, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.labels_df = pd.read_csv(labels_file) if labels_file else None
        if self.labels_df is not None:
            self.has_labels = True
            # Create a dictionary mapping class names to integers
            self.class_to_idx = {class_name: i for i, class_name in enumerate(
                self.labels_df["cell_line"].unique())}
        else:
            self.has_labels = False

    def __len__(self):
        if self.has_labels:
            return len(self.labels_df)
        else:
            # Here we'll need to compute the number of samples differently
            # If you have a list of test files, you can return its length
            return len(os.listdir(self.img_dir)) // 3  # assuming 3 images (red, blue, yellow) per sample


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.labels_df.iloc[idx] if self.has_labels else {'file_id': idx + 9633}

        sample_id = row['file_id']
        img_paths = [os.path.join(self.img_dir, f"{str(sample_id).zfill(5)}_{color}.png") for color in ["blue", "red", "yellow"]]
        imgs = [Image.open(img_path) for img_path in img_paths]
        img = torch.stack([torchvision.transforms.functional.to_tensor(im) for im in imgs]).squeeze(1)

        if self.transform:
            img = self.transform(img)

        if self.has_labels:
            # Convert label to integer
            label = self.class_to_idx[row['cell_line']]
            return img, label
        else:
            # When there's no labels, just return the image and file_id
            return img, sample_id



In [21]:
class_to_idx = train_dataset.class_to_idx
idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}

# Create Dataset and DataLoader for test data
test_dataset = CellLineDataset(
    img_dir="images_test/images_test/", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [23]:
model.eval()
predictions = []
file_ids = []
with torch.no_grad():
    for inputs, file_id in tqdm(test_loader, desc='Predicting'):
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        predictions.extend([idx_to_class[pred.item()] for pred in preds])
        file_ids.extend([id.item() for id in file_id])


# Save predictions to a CSV file
df_predictions = pd.DataFrame({'file_id': file_ids, 'cell_line': predictions})
df_predictions.to_csv('predictions_3.csv', index=False)


Predicting: 100%|██████████| 215/215 [01:30<00:00,  2.37it/s]
