In [28]:
import pandas as pd
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import Counter
import scipy
from scipy import stats
import sklearn

In [29]:
# Load the labels table; the preview in the next cell shows the columns `index`, `Class` and `video`.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
labels = pd.read_csv('/data/kpusteln/Fetal-RL/data_preparation/outputs/labels_corrected.csv')

In [30]:
# Preview the loaded table (pandas truncates the middle rows in the rendered output).
labels

Unnamed: 0,index,Class,video
0,1_1_1,1,1_1
1,1_1_2,1,1_1
2,1_1_3,1,1_1
3,1_1_4,1,1_1
4,1_1_5,1,1_1
...,...,...,...
261760,708_3_75,5,708_3
261761,708_3_76,5,708_3
261762,708_3_77,5,708_3
261763,708_3_78,5,708_3


In [31]:
## Data split
def probability_mass(data):
    """Return the relative frequency of each distinct value in `data`.

    args: data - iterable of hashable class labels
    returns: list of fractions (count / total), one per distinct label, in the
    order in which each label first appears in `data`; an empty list for empty input
    """
    class_counts = Counter(data) # label -> number of occurrences
    n_samples = sum(class_counts.values()) # total number of samples
    return [count / n_samples for count in class_counts.values()]
    

def train_test_split(data, train_size = 0.9, precision = 0.005, seed = None, max_attempts = None):
    """Split `data` into train, validation and test sets with similar class distributions.

    Videos (unique values of data['video']) are sampled at random into the train
    set; the rows of the remaining videos are halved at random into test and
    validation. The draw is repeated until the Wasserstein distance between the
    class distributions of the train and test sets is at most `precision`.

    The distance is computed over the class labels themselves (labels as support,
    relative frequencies as weights), so data['Class'] must be numeric and
    `precision` is expressed in class-label units.

    args: data - data frame with at least the columns 'video' and 'Class'
    train_size - fraction of the videos assigned to the train set, default 0.9
    precision - largest accepted train/test Wasserstein distance, default 0.005
        (the smaller the better, but it may take longer to generate sets)
    seed - optional integer; when given, the split is reproducible. When None
        the global `random` / numpy state is used.
    max_attempts - optional cap on the number of random draws. When None the
        search continues until `precision` is reached. When the cap is hit, the
        closest split found so far is returned and a notice is printed.

    returns: (train_set, val_set, test_set) data frames

    raises: ValueError if `train_size` leaves the train side or the test/validation
    side without any video, or if `max_attempts` is smaller than 1.

    NOTE: validation and test are separated row-wise, so rows of the same video
    can end up in both of them; only train is disjoint from the others by video.
    """

    print('Splitting data into train and test sets...')

    videos = list(data['video'].unique()) # list of videos
    n_train = int(train_size * len(videos)) # number of videos in the train set
    if not 0 < n_train < len(videos):
        raise ValueError(
            f'train_size={train_size} puts {n_train} of {len(videos)} videos in the train set; '
            'both the train and the test/validation side need at least one video')
    if max_attempts is not None and max_attempts < 1:
        raise ValueError('max_attempts must be at least 1')

    # A private generator keeps seeded runs reproducible without touching global state.
    rng = random.Random(seed) if seed is not None else random

    best_dist = float('inf') # guarantees at least one draw, whatever `precision` is
    best_split = None
    attempts = 0
    while best_dist > precision and (max_attempts is None or attempts < max_attempts):
        attempts += 1
        train_videos = set(rng.sample(videos, n_train)) # set -> O(1) membership tests
        in_train = data['video'].isin(train_videos)
        train_set = data.loc[in_train]
        testval_set = data.loc[~in_train]

        # Halve the held-out rows into test and validation.
        row_seed = rng.randrange(2**32) if seed is not None else None
        test_set = testval_set.sample(frac = 0.5, random_state = row_seed)
        val_set = testval_set.drop(test_set.index)

        # Compare the class distributions themselves: labels are the support and
        # their relative frequencies the weights. Passing only the frequencies as
        # values would ignore which class each frequency belongs to.
        train_pmf = train_set['Class'].value_counts(normalize = True)
        test_pmf = test_set['Class'].value_counts(normalize = True)
        dist = scipy.stats.wasserstein_distance(
            train_pmf.index.to_numpy(), test_pmf.index.to_numpy(),
            u_weights = train_pmf.to_numpy(), v_weights = test_pmf.to_numpy())

        if dist < best_dist:
            best_dist = dist
            best_split = (train_set, val_set, test_set)

    if best_dist > precision:
        print(f'Precision {precision} not reached in {attempts} attempts; '
              f'returning the closest split (distance {best_dist:.4f}).')
    print('Done!')
    return best_split
    
def histogram_class_plot(data, title):
    """Draw a 100-bin histogram of data['Class'] and save it to '<title>.png'.

    args: data - data frame with a 'Class' column
    title - figure title, also used as the output file name (without extension)
    """
    fig, ax = plt.subplots(figsize=(10,10))
    ax.hist(data['Class'], bins = 100)
    ax.set_xlabel('Class')
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    fig.savefig(title + '.png') # written relative to the current working directory






In [32]:
# Video-level split: 90% of the videos go to train; the rows of the remaining videos
# are halved at random into test and validation (see train_test_split).
# The split is random, so re-running this cell produces different sets.
train_set, val_set, test_set = train_test_split(labels, train_size = 0.9, precision = 0.005)

Splitting data into train and test sets...
Done!


In [None]:
"""new_labels:
0 - other
1 - head non-standard plane
2 - head standard plane
3 - abdomen non-standard plane
4 - abdomen standard plane
5 - femur non-standard plane
6 - femur standard plane"""

In [33]:
# Class balance of the train split (class ids follow the `new_labels` legend in this notebook).
train_set['Class'].value_counts()

3    96947
5    49773
6    35551
0    31740
1    14138
4     4442
2     2705
Name: Class, dtype: int64

In [34]:
# Disabled: optional label merging, applied identically to all three splits.
# The first group would map classes 3 and 5 to 1, the second group classes 4 and 6 to 2;
# per the `new_labels` legend that folds the abdomen/femur non-standard planes into
# class 1 and the abdomen/femur standard planes into class 2.
# train_set['Class'] = train_set['Class'].replace([3,5], [1,1])
# val_set['Class'] = val_set['Class'].replace([3,5], [1,1])
# test_set['Class'] = test_set['Class'].replace([3,5], [1,1])

# train_set['Class'] = train_set['Class'].replace([4,6], [2,2])
# val_set['Class'] = val_set['Class'].replace([4,6], [2,2])
# test_set['Class'] = test_set['Class'].replace([4,6], [2,2])

In [35]:
# Class balance of the validation split, for comparison with the train counts.
val_set['Class'].value_counts()

3    5611
5    2605
6    2019
0    1883
1     761
4     259
2      97
Name: Class, dtype: int64

In [36]:
# Ad-hoc check: second '_'-separated field of the first validation row's video id.
val_set['video'].iloc[0].split('_')[1]

'1'

In [37]:
# Persist the three splits as CSV; index = False leaves the DataFrame index out of the files.
# NOTE(review): absolute, machine-specific output paths; existing files are overwritten.
train_set.to_csv('/data/kpusteln/fetal/standard_plane/class_data/train_set7.csv', index = False)
val_set.to_csv('/data/kpusteln/fetal/standard_plane/class_data/val_set7.csv', index = False)
test_set.to_csv('/data/kpusteln/fetal/standard_plane/class_data/test_set7.csv', index = False)

In [38]:
from torchvision.models import efficientnet_v2_l

In [39]:
# Build EfficientNetV2-L with torchvision's default arguments (no weights argument is passed).
model = efficientnet_v2_l()

In [40]:
# Inspect the classification head: the output shows a final Linear layer with 1000 outputs,
# whereas the `new_labels` legend in this notebook lists 7 classes.
# NOTE(review): presumably this layer is replaced before training — confirm.
model.classifier

Sequential(
  (0): Dropout(p=0.4, inplace=True)
  (1): Linear(in_features=1280, out_features=1000, bias=True)
)