In [1]:
# The code in this notebook is copied from main.py

In [None]:
import torch
import sklearn
import numpy as np
import pandas as pd
from robust_losses import RobustLoss
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import make_column_transformer
from torch.utils.data import Dataset
from glob import glob

import os
import traceback

import warnings
warnings.filterwarnings('ignore')

In [2]:
class MyDataset(Dataset):
    """Map-style dataset that pairs row ``i`` of ``x`` with row ``i`` of ``y``.

    Both inputs must expose ``.shape`` and support integer indexing; their
    first dimension is treated as the number of samples.
    """

    def __init__(self, x, y):
        super().__init__()
        n_inputs, n_targets = x.shape[0], y.shape[0]
        # Features and targets must describe the same number of samples.
        assert n_inputs == n_targets
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from the first dimension of ``y``."""
        return self.y.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        features = self.x[index]
        target = self.y[index]
        return features, target

In [3]:
class Logistic_Reg_model(torch.nn.Module):
    """Binary classifier: two stacked linear layers followed by a sigmoid.

    There is no non-linearity between ``layer1`` and ``layer2``; the only
    activation is the final sigmoid, so the output is one value in (0, 1) per
    input row.
    """

    def __init__(self, no_input_features):
        super().__init__()
        hidden_width = 64
        self.layer1 = torch.nn.Linear(no_input_features, hidden_width)
        self.layer2 = torch.nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Map inputs with ``no_input_features`` columns to one sigmoid output per row."""
        hidden = self.layer1(x)
        logits = self.layer2(hidden)
        return torch.sigmoid(logits)

In [4]:
def testaccuracy(model,x_test,y_test):
    with torch.no_grad():
        y_pred=model(x_test)
        y_pred_class=y_pred.round()
        accuracy=(y_pred_class.eq(y_test).sum())/float(y_test.shape[0])
        return (accuracy.item())

In [5]:
def saveModel(model, fname="DRO_model.pth"):
    """Write ``model.state_dict()`` to ``fname`` with ``torch.save``.

    Args:
        model: module whose parameters are checkpointed.
        fname: destination path; defaults to ``"DRO_model.pth"`` in the
            current working directory, so existing ``saveModel(model)`` calls
            behave as before.
    """
    torch.save(model.state_dict(), fname)

In [6]:
# Each entry pairs a dataset directory with the column names of its CSV files,
# in file order. `paths` and `cnames` are the two parallel lists built from
# these pairs, so a directory and its columns can never get out of step.
paths, cnames = (list(column) for column in zip(
    ('../data/datasets/publiccov_ca/', [
        'AGEP', 'SCHL', 'MAR', 'SEX', 'DIS', 'ESP', 'CIT', 'MIG', 'MIL',
        'ANC', 'NATIVITY', 'DEAR', 'DEYE', 'DREM', 'PINCP', 'ESR', 'FER',
        'RAC1P', 'PUBCOV',
    ]),
    ('../data/datasets/employment_ca/', [
        'AGEP', 'SCHL', 'MAR', 'RELP', 'DIS', 'ESP', 'CIT', 'MIG', 'MIL',
        'ANC', 'NATIVITY', 'DEAR', 'DEYE', 'DREM', 'SEX', 'RAC1P', 'ESR',
    ]),
    ('../data/datasets/law_school/', [
        'zfygpa', 'zgpa', 'DOB_yr', 'weighted_lsat_ugpa', 'cluster_tier',
        'family_income', 'lsat', 'ugpa', 'isPartTime', 'sex', 'race',
        'pass_bar',
    ]),
    ('../data/datasets/diabetes/', [
        'race', 'sex', 'age', 'admissiontypeid', 'dischargedispositionid',
        'admissionsourceid', 'timeinhospital', 'numlabprocedures',
        'numprocedures', 'nummedications', 'numberoutpatient',
        'numberemergency', 'numberinpatient', 'diag1', 'diag2', 'diag3',
        'numberdiagnoses', 'maxgluserum', 'A1Cresult', 'metformin',
        'glimepiride', 'glipizide', 'glyburide', 'pioglitazone',
        'rosiglitazone', 'insulin', 'change', 'diabetesMed', 'readmitted',
    ]),
))

In [7]:
def _load_split(csv_path, cols):
    """Read one header-less CSV split, name its columns, and strip string cells."""
    frame = pd.read_csv(csv_path, header=None)
    frame.columns = cols
    # Only object (string) columns have a .str accessor; others pass through.
    return frame.apply(lambda x: x.str.strip() if x.dtype == "object" else x)


def task(path,cols,num,number_of_epochs=100,batch_size=1000):
    """Train one model with a robust loss on a dataset and save its test predictions.

    Reads ``train.csv`` and ``test.csv`` from ``path``, drops the target (the
    last entry of ``cols``) and the sex column from the features, one-hot
    encodes the remaining string columns, trains ``Logistic_Reg_model`` with
    per-sample BCE aggregated by ``RobustLoss``, and writes the rounded test
    predictions of the best checkpoint to
    ``<path>/preds/OG_DRO_pred_<num>.pt``.

    Args:
        path: dataset directory containing ``train.csv`` and ``test.csv``.
        cols: column names for both CSV files, target last; must contain
            ``'SEX'`` or ``'sex'``.
        num: run identifier, used only in the log line and output file name.
        number_of_epochs: training epochs (default 100, the previous constant).
        batch_size: mini-batch size (default 1000, the previous constant).

    Raises:
        ValueError: if ``number_of_epochs`` is smaller than 1.

    Side effects:
        Overwrites ``DRO_model.pth`` in the working directory (via
        ``saveModel``) and prints training progress.
    """
    if number_of_epochs < 1:
        raise ValueError("number_of_epochs must be at least 1")

    print(path,num)

    # The sensitive attribute is spelled differently across datasets.
    ss = 'SEX' if 'SEX' in cols else 'sex'
    tgt = cols[-1]

    # os.path.join keeps this working whether or not `path` ends with a slash.
    train_df = _load_split(os.path.join(path, 'train.csv'), cols)
    test_df = _load_split(os.path.join(path, 'test.csv'), cols)

    x_train = train_df.drop([tgt,ss],axis=1)
    x_test = test_df.drop([tgt,ss],axis=1)

    # Fit the encoder on train+test so every category present in either split
    # gets a column and both splits end up with the same feature layout.
    x_merged = pd.concat([x_train,x_test])

    try:
        encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        # scikit-learn < 1.2 only knows the older `sparse` keyword.
        encoder = OneHotEncoder(sparse=False)

    ohe = make_column_transformer(
        (encoder, x_merged.dtypes == 'object'),
        remainder='passthrough', verbose_feature_names_out=False)
    ohe.fit(x_merged)

    x_train = torch.from_numpy(np.asarray(ohe.transform(x_train)).astype(np.float32))
    x_test = torch.from_numpy(np.asarray(ohe.transform(x_test)).astype(np.float32))

    # Encode labels over both splits at once. Factorizing each split on its
    # own could assign different codes to the same label whenever the splits
    # do not contain exactly the same set of label values.
    n_train = len(train_df)
    label_codes = pd.concat([train_df[tgt], test_df[tgt]]).factorize(sort=True)[0]
    label_codes = label_codes.astype(np.float32)
    y_train = torch.from_numpy(label_codes[:n_train]).unsqueeze(1)
    y_test = torch.from_numpy(label_codes[n_train:]).unsqueeze(1)

    traindata = MyDataset(x_train, y_train)
    trainloader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True)

    n_features = x_train.shape[1]
    model=Logistic_Reg_model(n_features)

    # reduction='none' keeps one loss value per sample for robust_loss to aggregate.
    criterion=torch.nn.BCELoss(reduction='none')
    robust_loss = RobustLoss(geometry='chi-square', size=1.0, reg=0.5)
    optimizer=torch.optim.Adam(model.parameters())#,lr=0.001, weight_decay=0.0001)

    best_accuracy = 0.0

    for epoch in range(number_of_epochs):
        running_loss = 0.0
        for i, (x_b, y_b) in enumerate(trainloader, 0):
            optimizer.zero_grad()
            y_prediction=model(x_b)
            # reshape(-1) instead of squeeze(): a final batch holding a single
            # sample must stay 1-D rather than collapse to a 0-d tensor.
            loss=robust_loss(criterion(y_prediction.reshape(-1),y_b.reshape(-1)))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (i)%10 == 9:
                print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0
        # NOTE(review): the checkpoint is selected by accuracy on the test
        # split, so the saved predictions are optimistically biased; consider
        # selecting on a held-out validation split instead.
        accuracy = testaccuracy(model,x_test,y_test)
        print('accuracy:', accuracy)
        print('best:', best_accuracy)
        # Always checkpoint after the first epoch so the file loaded below
        # belongs to this run, even if accuracy never rises above 0.0;
        # otherwise a stale DRO_model.pth from an earlier run would be used.
        if epoch == 0 or accuracy > best_accuracy:
            saveModel(model)
            best_accuracy = accuracy

    finalmodel = Logistic_Reg_model(n_features)
    finalmodel.load_state_dict(torch.load('DRO_model.pth'))

    with torch.no_grad():
        y_pred_class=finalmodel(x_test).round()

    pred_dir = os.path.join(path, 'preds')
    # exist_ok tolerates reruns; any other failure (e.g. permissions) now
    # surfaces here instead of being swallowed by a bare except.
    os.makedirs(pred_dir, exist_ok=True)
    torch.save(y_pred_class, os.path.join(pred_dir, 'OG_DRO_pred_'+str(num)+'.pt'))


In [None]:
# Run `task` ten times on every dataset; `i` labels the run's prediction file.
for path,cols in zip(paths,cnames):
    for i in range(10):
        try:
            task(path,cols,i)
        except Exception:
            # Log the failure and continue with the next run. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit stop the whole sweep instead of just one run.
            traceback.print_exc()

../data/datasets/publiccov_ca/ 0
[1,    10] loss: 64.281
[1,    20] loss: 80.615
[1,    30] loss: 81.174
[1,    40] loss: 81.030
[1,    50] loss: 80.097
[1,    60] loss: 82.099
[1,    70] loss: 81.351
[1,    80] loss: 80.831
[1,    90] loss: 80.907
[1,   100] loss: 82.435
[1,   110] loss: 82.422
accuracy: 0.627801239490509
best: 0.0
[2,    10] loss: 83.568
[2,    20] loss: 81.078
[2,    30] loss: 80.252
[2,    40] loss: 81.581
[2,    50] loss: 80.803
[2,    60] loss: 82.603
[2,    70] loss: 81.042
[2,    80] loss: 79.316
[2,    90] loss: 82.614
[2,   100] loss: 79.892
[2,   110] loss: 81.588
accuracy: 0.6341525316238403
best: 0.627801239490509
[3,    10] loss: 81.025
[3,    20] loss: 80.999
[3,    30] loss: 80.590
[3,    40] loss: 80.239
[3,    50] loss: 81.857
[3,    60] loss: 80.920
[3,    70] loss: 81.992
[3,    80] loss: 81.640
[3,    90] loss: 79.726
[3,   100] loss: 81.584
[3,   110] loss: 80.886
accuracy: 0.6340803503990173
best: 0.6341525316238403
[4,    10] loss: 81.235
[4,   