In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix, accuracy_score

torch.manual_seed(42)

<torch._C.Generator at 0x7ff221071dd0>

In [2]:
train_df = pd.read_csv('n_train.csv')
dev_df = pd.read_csv('n_dev.csv')

## Preprocessing data

In [3]:
train_size = train_df.shape[0]
dev_size = dev_df.shape[0]

In [4]:
train_size, dev_size

(429, 185)

In [5]:
df = pd.concat([train_df, dev_df])

In [6]:
categorical_cols = ['Gender', 'Married', 'Dependents', 'Education', 'Self_Employed', 'Credit_History', 'Property_Area']
numerical_cols = ['ApplicantIncome', 'CoapplicantIncome', 'LoanAmount', 'Loan_Amount_Term']
label_col = 'Loan_Status'

In [7]:
df[label_col] = df[label_col].replace({-1: 0})

In [8]:
# One hot encoding 
df[categorical_cols] = df[categorical_cols].astype(str)
one_hot = pd.get_dummies(df[categorical_cols])
df = df.drop(categorical_cols, axis=1)
df = pd.concat([one_hot, df], axis=1, sort=False)

In [9]:
df.head()

Unnamed: 0,Gender_Female,Gender_Male,Married_No,Married_Yes,Dependents_0,Dependents_1,Dependents_2,Dependents_3+,Education_Graduate,Education_Not Graduate,...,Credit_History_0.0,Credit_History_1.0,Property_Area_Rural,Property_Area_Semiurban,Property_Area_Urban,ApplicantIncome,CoapplicantIncome,LoanAmount,Loan_Amount_Term,Loan_Status
0,0,1,0,1,0,0,1,0,1,0,...,0,1,0,0,1,2500.0,1840.0,109.0,360.0,1.0
1,0,1,1,0,1,0,0,0,1,0,...,0,1,0,1,0,5941.0,4232.0,296.0,360.0,1.0
2,0,1,0,1,0,0,0,1,0,1,...,1,0,0,1,0,4931.0,0.0,128.0,360.0,0.0
3,0,1,0,1,1,0,0,0,0,1,...,0,1,1,0,0,2894.0,2792.0,155.0,360.0,1.0
4,0,1,0,1,1,0,0,0,1,0,...,0,1,0,0,1,2500.0,3796.0,120.0,360.0,1.0


In [10]:
train_df = df.iloc[:train_size]
dev_df = df.iloc[train_size:]

assert train_df.shape[0] == train_size
assert dev_df.shape[0] == dev_size

In [11]:
min_max_scaler = preprocessing.MinMaxScaler()
min_max_scaler.fit(train_df[numerical_cols])

MinMaxScaler(copy=True, feature_range=(0, 1))

In [12]:
train_df[numerical_cols] = min_max_scaler.transform(train_df[numerical_cols])
dev_df[numerical_cols] = min_max_scaler.transform(dev_df[numerical_cols])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[item] = s
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the cavea

In [13]:
X_train = train_df.drop(label_col, axis=1).values
y_train = train_df[label_col].values

X_dev = dev_df.drop(label_col, axis=1).values
y_dev = dev_df[label_col].values

## Model

In [14]:
print('Examples:{}    Features:{}'.format(X_train.shape[0], X_train.shape[1]))

Examples:429    Features:21


In [15]:
model = nn.Sequential(
    nn.Linear(X_train.shape[1], 32),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid()
    )

In [16]:
X_train = torch.Tensor(X_train)
y_train = torch.Tensor(y_train).float()

X_dev = torch.Tensor(X_dev)
y_dev = torch.Tensor(y_dev).float()

In [17]:
train_dataset = TensorDataset(X_train, y_train)
train_dataloader = DataLoader(train_dataset, batch_size=4)

dev_dataset = TensorDataset(X_dev, y_dev)
dev_dataloader = DataLoader(dev_dataset, batch_size=4) # create your dataloader

In [18]:
EPOCHS = 100
LR = 1e-4
CHECKPOINT = 'simple.pt'

In [19]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [20]:
best_dev_loss = 10e18

for epoch in range(EPOCHS):  # loop over the dataset multiple times

    ############## Train
    model.train()
    tr_loss = 0.0
    t = tqdm(enumerate(train_dataloader, 0), desc='Progress')
    for i, data in t:
        t.set_postfix({
            'Epoch': epoch + 1,
            'Batch': i + 1, 
            'Train loss': tr_loss / (i + 1)
        })
    
        inputs, labels = data

        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs.flatten(), labels)
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
        
    ############## Validation
    model.eval()
    dev_loss = 0.0
    t = tqdm(enumerate(dev_dataloader, 0), desc='Progress')
    for i, data in t:
        t.set_postfix({
            'Epoch': epoch + 1,
            'Batch': i + 1, 
            'Dev loss': dev_loss / (i + 1)
        })
    
        inputs, labels = data
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs.flatten(), labels)
            dev_loss += loss.item()
    
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        torch.save(model.state_dict(), CHECKPOINT)
        
print('Finished Training. Best dev loss: {}'.format(best_dev_loss / (len(dev_dataloader))))

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…


Finished Training. Best dev loss: 0.44922810317354


In [21]:
model.load_state_dict(torch.load(CHECKPOINT))

<All keys matched successfully>

In [22]:
model.eval()

with torch.no_grad():
    # Train
    y_train_pred = np.round(model(X_train))
    train_cm = confusion_matrix(y_train, y_train_pred)
    train_acc = accuracy_score(y_train, y_train_pred)
    
    # Validation
    y_dev_pred = np.round(model(X_dev))
    dev_cm = confusion_matrix(y_dev, y_dev_pred)
    dev_acc = accuracy_score(y_dev, y_dev_pred)

In [23]:
train_cm, train_acc

(array([[ 59,  75],
        [  7, 288]]),
 0.8088578088578089)

In [24]:
dev_cm, dev_acc

(array([[ 28,  30],
        [  4, 123]]),
 0.8162162162162162)