In [1]:
# Copying Code from main.py

In [1]:
import torch
import sklearn
import numpy as np
import pandas as pd
from robust_losses import RobustLoss
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import make_column_transformer
from torch.utils.data import Dataset
from glob import glob

import os

import warnings
warnings.filterwarnings('ignore')

In [2]:
class MyDataset(Dataset):
    def __init__(self, x, y):
        super(MyDataset, self).__init__()
        assert x.shape[0] == y.shape[0] # assuming shape[0] = dataset size
        self.x = x
        self.y = y


    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

In [3]:
class Logistic_Reg_model(torch.nn.Module):
    def __init__(self,no_input_features):
        super(Logistic_Reg_model,self).__init__()
        self.layer1=torch.nn.Linear(no_input_features,64)
        self.layer2=torch.nn.Linear(64,1)
    def forward(self,x):
        y_predicted=self.layer1(x)
        y_predicted=torch.sigmoid(self.layer2(y_predicted))
        return y_predicted

In [4]:
def testaccuracy():
    with torch.no_grad():
        y_pred=model(x_test)
        y_pred_class=y_pred.round()
        accuracy=(y_pred_class.eq(y_test).sum())/float(y_test.shape[0])
        return (accuracy.item())

In [5]:
def saveModel():
    fname = "DRO_model_lsac.pth"
    torch.save(model.state_dict(), fname)

In [6]:
path = '../data/datasets/law_school/'
synthfols = glob("../data/datasets/law_school/synthetic/*/")
paths = [path]
paths.extend(synthfols)

In [None]:
for path in paths:
    
#     if os.path.exists(path+'preds/DRO_pred.pt') == False:

    train_df = pd.read_csv(path+'train.csv',header=None)

    train_df.columns = ['zfygpa', 'zgpa', 'DOB_yr', 'weighted_lsat_ugpa', 'cluster_tier',
       'family_income', 'lsat', 'ugpa', 'isPartTime', 'sex', 'race',
       'pass_bar']

#     train_df.index = train_df['SEX']

    train_df = train_df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)

    test_df = pd.read_csv(path+'test.csv',header=None)

    test_df.columns = ['zfygpa', 'zgpa', 'DOB_yr', 'weighted_lsat_ugpa', 'cluster_tier',
       'family_income', 'lsat', 'ugpa', 'isPartTime', 'sex', 'race',
       'pass_bar']
    
#     test_df = test_df.drop(['SEX'],axis=1)

#     test_df.index = test_df['SEX']

    test_df = test_df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)


    x_train = train_df.drop(['pass_bar','sex'],axis=1)
    x_test = test_df.drop(['pass_bar','sex'],axis=1)

    x_merged = pd.concat([x_train,x_test])

    ohe = make_column_transformer(
        (OneHotEncoder(sparse=False), x_merged.dtypes == 'object'),
        remainder='passthrough', verbose_feature_names_out=False)

    x_merged_temp  = pd.DataFrame(ohe.fit_transform(x_merged), columns=ohe.get_feature_names_out(), index=x_merged.index)

    x_train  = pd.DataFrame(ohe.transform(x_train), columns=ohe.get_feature_names_out(), index=x_train.index)
    x_test = pd.DataFrame(ohe.transform(x_test), columns=ohe.get_feature_names_out(), index=x_test.index)

    y_train = pd.Series(train_df['pass_bar'])
    y_test = pd.Series(test_df['pass_bar'])

    y_train = pd.Series(y_train.factorize(sort=True)[0], index=y_train.index)
    y_test = pd.Series(y_test.factorize(sort=True)[0], index=y_test.index)

    x_train=torch.from_numpy(x_train.to_numpy().astype(np.float32))
    x_test=torch.from_numpy(x_test.to_numpy().astype(np.float32))
    y_train=torch.from_numpy(y_train.to_numpy().astype(np.float32))
    y_test=torch.from_numpy(y_test.to_numpy().astype(np.float32))

    y_train=y_train.view(y_train.shape[0],1)
    y_test=y_test.view(y_test.shape[0],1)

    traindata = MyDataset(x_train, y_train)

    trainloader = torch.utils.data.DataLoader(traindata, batch_size=1000, shuffle=True)

    n_features = x_train.shape[1]
    model=Logistic_Reg_model(n_features)

    criterion=torch.nn.BCELoss(reduction='none')
    robust_loss = RobustLoss(geometry='chi-square', size=1.0, reg=0.5)
    optimizer=torch.optim.Adam(model.parameters())#,lr=0.001, weight_decay=0.0001)

    number_of_epochs=100
    best_accuracy = 0.0

    for epoch in range(number_of_epochs):
        running_loss = 0.0
        for i, (x_b, y_b) in enumerate(trainloader, 0):
            optimizer.zero_grad()
            y_prediction=model(x_b)
            loss=robust_loss(criterion(y_prediction.squeeze(),y_b.squeeze()))
            loss.backward()
            optimizer.step()
            running_loss += loss.item() 
            if (i)%10 == 9:
                print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0
        accuracy = testaccuracy()
        print('accuracy:', accuracy)
        print('best:', best_accuracy)
        if accuracy > best_accuracy:
            saveModel()
            best_accuracy = accuracy

    finalmodel = Logistic_Reg_model(n_features)
    finalmodel.load_state_dict(torch.load('DRO_model_lsac.pth'))

    with torch.no_grad():
        y_pred=finalmodel(x_test)
        y_pred_class=y_pred.round()
        try:
            os.mkdir(path+'preds/')
        except:
            pass
        torch.save(y_pred_class,path+'preds/OG_DRO_pred.pt')


[1,    10] loss: 99.929
[1,    20] loss: 99.924
accuracy: 0.19676144421100616
best: 0.0
[2,    10] loss: 99.926
[2,    20] loss: 99.928
accuracy: 0.19676144421100616
best: 0.19676144421100616
[3,    10] loss: 99.927
[3,    20] loss: 99.926
accuracy: 0.19676144421100616
best: 0.19676144421100616
[4,    10] loss: 99.928
[4,    20] loss: 99.927
accuracy: 0.19676144421100616
best: 0.19676144421100616
[5,    10] loss: 99.929
[5,    20] loss: 99.924
accuracy: 0.19676144421100616
best: 0.19676144421100616
[6,    10] loss: 99.922
[6,    20] loss: 99.930
accuracy: 0.19676144421100616
best: 0.19676144421100616
[7,    10] loss: 99.926
[7,    20] loss: 99.929
accuracy: 0.19676144421100616
best: 0.19676144421100616
[8,    10] loss: 99.927
[8,    20] loss: 99.929
accuracy: 0.19676144421100616
best: 0.19676144421100616
[9,    10] loss: 99.927
[9,    20] loss: 99.926
accuracy: 0.19676144421100616
best: 0.19676144421100616
[10,    10] loss: 99.927
[10,    20] loss: 99.926
accuracy: 0.19676144421100616
