In [None]:
import pandas as pd
import numpy as np
import ast

import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torch import nn, optim
from torchvision import datasets, utils, models
# from torchinfo import summary
import torch.nn.functional as F
from torchvision.transforms import v2

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
import os
from tqdm import tqdm
import altair as alt
alt.data_transformers.enable("vegafusion")


## Data Cleaning

In [None]:
# For small dataset
# labels = pd.read_csv('data/small_data/labels.csv', index_col=0)
# print(labels.shape)
# labels.head()

In [None]:
# For the full dataset: load every label CSV in the folder and stack them into one frame.
folder_path = 'data/labels/'

# os.listdir() returns entries in arbitrary, filesystem-dependent order; sorting
# makes the row order of `labels` (and anything derived from it) reproducible.
csv_files = sorted(name for name in os.listdir(folder_path) if name.endswith('.csv'))
if not csv_files:
    raise FileNotFoundError(f"No .csv label files found in {folder_path!r}")

# index_col=0 uses each file's first column as the row index; ignore_index=True
# below then discards those per-file indexes and renumbers rows 0..n-1.
dataframes = [
    pd.read_csv(os.path.join(folder_path, name), index_col=0)
    for name in csv_files
]

# Concatenate all the DataFrames in the list
labels = pd.concat(dataframes, ignore_index=True)
print(labels.shape)
labels.head()

In [None]:
# Keep only label rows whose image file exists and can be opened by PIL.
clean_img_code = []
for img_code in labels['index']:
    try:
        # `with` closes the file deterministically instead of leaving the handle
        # open until the Image object happens to be garbage-collected.
        with Image.open(f"data/img/{img_code}.png"):
            pass
    except Exception:
        # Best-effort: skip missing/unreadable images. Catching Exception (not a
        # bare except) lets KeyboardInterrupt through so the scan can be interrupted.
        continue
    clean_img_code.append(img_code)
print(f'{len(clean_img_code)} rows found corresponding image')
# Filter and renumber in one expression rather than mutating a filtered slice in place.
labels = labels[labels['index'].isin(clean_img_code)].reset_index(drop=True)
print(labels.shape)

In [None]:
# Rows whose gene list is the empty literal "[]" are relabelled as '["Normal"]'
# (still a string, matching the list-literal format of the rest of the column).
labels["genes"] = labels["genes"].replace("[]", '["Normal"]')

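Remove heterozygous genes: drop every gene entry whose name contains `het` (case-insensitive), and label rows left with no genes as `Normal`.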

In [None]:
# Parse each row's gene-list string into a Python list.
list_genes = [ast.literal_eval(gene) for gene in labels['genes']]

# Drop heterozygous entries: any gene name containing 'het' (case-insensitive).
list_genes_no_het = [
    [item for item in sublist if 'het' not in item.lower()]
    for sublist in list_genes
]

# A row whose genes were all heterozygous is left empty -> relabel it 'Normal'.
list_genes_no_het = [sublist or ['Normal'] for sublist in list_genes_no_het]

labels['genes'] = list_genes_no_het

# Flat list of every gene occurrence, in row order.
clean_genes = [gene for sublist in list_genes_no_het for gene in sublist]

# sorted() rather than list(set(...)): the iteration order of a set of strings
# changes between interpreter sessions (hash randomisation), which would
# reshuffle the one-hot column order - and the label index mapping - on each run.
clean_possible_genes = sorted(set(clean_genes))
print(f'Number of possible genes: {len(clean_possible_genes)}')
clean_possible_genes[:5]

In [None]:
# Append one zero-initialised integer column per possible gene next to the labels.
n_label_rows = labels.shape[0]
n_gene_cols = len(clean_possible_genes)
gene_extension_df = pd.DataFrame(
    np.zeros([n_label_rows, n_gene_cols]),
    dtype=int,
    columns=clean_possible_genes,
)
labels_extended = pd.concat([labels, gene_extension_df], axis=1)
print(labels_extended.shape)
labels_extended.head()

In [None]:
# Number of parsed gene lists (one per row of `labels`).
len(list_genes)

In [None]:
# Number of rows in the extended label frame.
labels_extended.shape[0]

In [None]:
# Per-row count of non-heterozygous genes.
num_genes_not_het = [len(lst) for lst in list_genes_no_het]

# Set each row's gene columns to 1. Iterating the per-row gene lists directly
# replaces the manual row counter over the flattened gene list, so the fill
# cannot drift out of step with the rows.
for row_label, row_genes in zip(labels_extended.index, list_genes_no_het):
    # dict.fromkeys de-duplicates while preserving order, so a gene repeated
    # within one row is not passed to .loc twice.
    labels_extended.loc[row_label, list(dict.fromkeys(row_genes))] = 1
print(labels_extended.shape)
labels_extended.head()

In [None]:
# Sanity check: each row's one-hot gene columns sum to the number of genes it lists.
# Select the gene columns explicitly; select_dtypes('int') would also sum any
# integer-typed column already present in `labels`, breaking the comparison.
assert labels_extended[clean_possible_genes].sum(axis=1).tolist() == [len(lst) for lst in list_genes_no_het]