In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import os

class MammographyDataset(Dataset):
    """Image-classification dataset backed by a CSV index of image files.

    Each CSV row describes one sample: ``csv_image_col`` holds the image path
    relative to ``img_dir`` and ``csv_label_col`` holds its class label.
    Label values are mapped to contiguous integer indices ``0..n_classes-1``
    (in sorted order of the distinct values) so they can be fed directly to
    ``nn.CrossEntropyLoss``, which expects class indices as targets.

    Args:
        csv_file: Path or buffer passed to ``pandas.read_csv``.
        csv_image_col: Column holding image paths relative to ``img_dir``.
        img_dir: Directory the image paths are resolved against.
        transform: Optional callable applied to each loaded RGB PIL image.
        csv_label_col: Column holding the class label. Defaults to
            ``"label"``. NOTE(review): the real label column name in the CSV
            is not visible here; confirm.

    Attributes:
        classes: Sorted list of the distinct label values.
        class_to_idx: Mapping from label value to integer class index.

    Raises:
        KeyError: If the image or label column is missing from the CSV.
    """

    def __init__(self, csv_file, csv_image_col, img_dir, transform=None,
                 csv_label_col="label"):
        self.data = pd.read_csv(csv_file)
        self.data_image_col = csv_image_col
        self.data_label_col = csv_label_col
        self.img_dir = img_dir
        self.transform = transform

        # Fail fast with a clear message instead of a KeyError deep inside a
        # DataLoader worker on the first batch.
        missing = [col for col in (csv_image_col, csv_label_col)
                   if col not in self.data.columns]
        if missing:
            raise KeyError(f"columns not found in {csv_file!r}: {missing}")

        # Sorted so the label -> index assignment is deterministic across runs.
        self.classes = sorted(self.data[csv_label_col].unique().tolist())
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the row at position ``idx``.

        ``image`` is the RGB PIL image, or whatever ``transform`` returns for
        it. ``class_index`` is an ``int`` from ``class_to_idx``.
        """
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, str(row[self.data_image_col]))
        # The context manager closes the file handle; convert() returns a
        # new, fully loaded image, so it stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        label = self.class_to_idx[row[self.data_label_col]]
        return image, label

# NOTE(review): the original MammographyDataset(...) call omitted the required
# csv_image_col and img_dir arguments and raised TypeError. The values below
# are placeholders; confirm them against classification.csv and the on-disk
# image layout.
CSV_FILE = "classification.csv"
IMAGE_COL = "image_id"  # TODO confirm: column holding image paths
IMG_DIR = "images"      # TODO confirm: root directory the image paths are relative to
NUM_CLASSES = 4         # must equal the number of distinct labels in the CSV
BATCH_SIZE = 32
LEARNING_RATE = 0.001
num_epochs = 10

# The pretrained ResNet-18 weights expect ImageNet-normalized inputs; without
# Normalize the fine-tuning starts from a mismatched input distribution.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = MammographyDataset(
    csv_file=CSV_FILE,
    csv_image_col=IMAGE_COL,
    img_dir=IMG_DIR,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity check only when the dataset exposes its label set; a mismatch would
# otherwise either crash inside CrossEntropyLoss or silently waste outputs.
dataset_classes = getattr(dataset, "classes", None)
if dataset_classes is not None and len(dataset_classes) != NUM_CLASSES:
    raise ValueError(
        f"dataset has {len(dataset_classes)} classes, model head expects {NUM_CLASSES}"
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the
        # per-epoch mean (loss.item() is already averaged over the batch).
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

print('Training complete')

KeyboardInterrupt: 