In [4]:
import pandas as pd
import os
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image

In [2]:
# custom dataset class

class FlatImageDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        """
        
        """
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]["filename"]
        label = self.df.iloc[idx]["label"]

        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

In [9]:
data = pd.read_csv("gz2_filename_mapping.csv") # uses objid
gz2_data = pd.read_csv("gz2_hart16.csv") # uses dr7objid

flag_labels = gz2_data[["t04_spiral_a08_spiral_flag", "dr7objid"]]

print(data.columns.tolist())
print(flag_labels.columns.tolist())

['objid', 'sample', 'asset_id']
['t04_spiral_a08_spiral_flag', 'dr7objid']


In [None]:
data = data.drop(columns=["sample", "asset_id"])

In [17]:
# dataset 

dataset = pd.merge(data, flag_labels, left_on="objid", right_on="dr7objid", how="inner")
dataset = dataset.drop(columns=["dr7objid"])

dataset.to_csv("dataset.csv", index=False)
print(dataset.columns.tolist())

['objid', 't04_spiral_a08_spiral_flag']


In [8]:
dataset.rename(columns={"t04_spiral_a08_spiral_flag": "label"})

                     objid    sample  asset_id  t04_spiral_a08_spiral_flag  \
0       587722981741363294  original         3                           0   
1       587722981741363323  original         4                           0   
2       587722981741559888  original         5                           0   
3       587722981741625481  original         6                           0   
4       587722981741625484  original         7                           0   
...                    ...       ...       ...                         ...   
239690  588015510368681992  stripe82    295294                           1   
239691  588015510368682105  stripe82    295295                           1   
239692  588015510368682132  stripe82    295296                           0   
239693  588015510636265643  stripe82    295304                           1   
239694  588015510636265731  stripe82    295305                           1   

                  dr7objid  
0       587722981741363294  
1    

In [5]:
# data

train_val_df, test_df = train_test_split(dataset, test_size=0.15, stratify=dataset["label"], random_state=26)
train_df, val_df = train_test_split(train_val_df, test_size=0.15/(0.85), stratify=train_val_df["label"], random_state=26)

train_df.to_csv("train_df.csv", index=False)
val_df.to_csv("val_df.csv", index=False)
test_df.to_csv("test_df.csv", index=False)

print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

KeyError: 'label'

In [None]:
transform = transforms.Compose([
    transforms.Resize((128, 128)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

In [None]:
train_dataset = FlatImageDataset("train.csv", "", transform=transform)
val_dataset = FlatImageDataset("val.csv", "", transform=transform)
test_dataset = FlatImageDataset("test.csv", "", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Model

In [None]:
class CNN(nn.Module):
    def __init__(self, num_classes, in_channel ,input_size):
        