In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix, accuracy_score

In [2]:
train_df = pd.read_csv('train.csv')
dev_df = pd.read_csv('dev.csv')

## Preprocessing data

In [3]:
train_size = train_df.shape[0]
dev_size = dev_df.shape[0]

In [4]:
train_size, dev_size

(483, 207)

In [5]:
df = pd.concat([train_df, dev_df])

In [6]:
df['A16'] = df['A16'].replace({-1: 0})

In [7]:
# One hot encoding 
df[['A1', 'A4', 'A5', 'A6', 'A7', 'A9', 'A10', 'A12', 'A13']] = df[['A1', 'A4', 'A5', 'A6', 'A7', 'A9', 'A10', 'A12', 'A13']].astype(str)
one_hot = pd.get_dummies(df[['A1', 'A4', 'A5', 'A6', 'A7', 'A9', 'A10', 'A12', 'A13']])
df = df.drop(['A1', 'A4', 'A5', 'A6', 'A7', 'A9', 'A10', 'A12', 'A13'], axis=1)
df = pd.concat([one_hot, df], axis=1, sort=False)

In [8]:
df.head()

Unnamed: 0,A1_10.0,A1_11.0,A4_40.0,A4_41.0,A4_42.0,A5_50.0,A5_51.0,A5_52.0,A6_600.0,A6_601.0,...,A13_130.0,A13_131.0,A13_132.0,A2,A3,A8,A11,A14,A15,A16
0,1,0,0,1,0,0,1,0,0,0,...,1,0,0,21.67,1.165,2.5,1.0,180.0,20.0,0.0
1,1,0,0,1,0,0,1,0,0,0,...,1,0,0,23.58,0.46,2.625,6.0,208.0,347.0,0.0
2,0,1,1,0,0,1,0,0,1,0,...,1,0,0,47.75,8.0,7.875,6.0,0.0,1260.0,1.0
3,1,0,1,0,0,1,0,0,1,0,...,1,0,0,31.42,15.5,0.5,0.0,120.0,0.0,0.0
4,1,0,1,0,0,1,0,0,0,0,...,1,0,0,25.67,12.5,1.21,67.0,140.0,258.0,1.0


In [9]:
train_df = df.iloc[:train_size]
dev_df = df.iloc[train_size:]

assert train_df.shape[0] == train_size
assert dev_df.shape[0] == dev_size

In [10]:
min_max_scaler = preprocessing.MinMaxScaler()
min_max_scaler.fit(train_df[['A2', 'A3', 'A8', 'A11', 'A14', 'A15']])

MinMaxScaler(copy=True, feature_range=(0, 1))

In [11]:
train_df[['A2', 'A3', 'A8', 'A11', 'A14', 'A15']] = min_max_scaler.transform(train_df[['A2', 'A3', 'A8', 'A11', 'A14', 'A15']])
dev_df[['A2', 'A3', 'A8', 'A11', 'A14', 'A15']] = min_max_scaler.transform(dev_df[['A2', 'A3', 'A8', 'A11', 'A14', 'A15']].values)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[item] = s
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the cavea

In [12]:
X_train = train_df.drop('A16', axis=1).values
y_train = train_df['A16'].values


X_dev = dev_df.drop('A16', axis=1).values
y_dev = dev_df['A16'].values

## Model

In [13]:
print('Examples:{}    Features:{}'.format(X_train.shape[0], X_train.shape[1]))

Examples:483    Features:46


In [14]:
model = nn.Sequential(
    nn.Linear(X_train.shape[1], 32),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid()
    )

In [15]:
X_train = torch.Tensor(X_train)
y_train = torch.Tensor(y_train).float()

X_dev = torch.Tensor(X_dev)
y_dev = torch.Tensor(y_dev).float()

In [16]:
train_dataset = TensorDataset(X_train, y_train)
train_dataloader = DataLoader(train_dataset, batch_size=4)

dev_dataset = TensorDataset(X_dev, y_dev)
dev_dataloader = DataLoader(dev_dataset, batch_size=4) # create your dataloader

In [17]:
EPOCHS = 100
LR = 5e-5
CHECKPOINT = 'simple.pt'

In [18]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [19]:
best_dev_loss = 10e18

for epoch in range(EPOCHS):  # loop over the dataset multiple times

    ############## Train
    model.train()
    tr_loss = 0.0
    t = tqdm(enumerate(train_dataloader, 0), desc='Progress')
    for i, data in t:
        t.set_postfix({
            'Epoch': epoch + 1,
            'Batch': i + 1, 
            'Train loss': tr_loss / (i + 1)
        })
    
        inputs, labels = data

        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs.flatten(), labels)
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
        
    ############## Validation
    model.eval()
    dev_loss = 0.0
    t = tqdm(enumerate(dev_dataloader, 0), desc='Progress')
    for i, data in t:
        t.set_postfix({
            'Epoch': epoch + 1,
            'Batch': i + 1, 
            'Dev loss': dev_loss / (i + 1)
        })
    
        inputs, labels = data
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs.flatten(), labels)
            dev_loss += loss.item()
    
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        torch.save(model.state_dict(), CHECKPOINT)
        
print('Finished Training. Best dev loss: {}'.format(best_dev_loss / (len(dev_dataloader))))

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Progress', max=1.0, style=ProgressStyle…


Finished Training. Best dev loss: 0.3518047364285359


In [20]:
model.load_state_dict(torch.load(CHECKPOINT))

<All keys matched successfully>

In [21]:
model.eval()

with torch.no_grad():
    # Train
    y_train_pred = np.round(model(X_train))
    train_cm = confusion_matrix(y_train, y_train_pred)
    train_acc = accuracy_score(y_train, y_train_pred)
    
    # Validation
    y_dev_pred = np.round(model(X_dev))
    dev_cm = confusion_matrix(y_dev, y_dev_pred)
    dev_acc = accuracy_score(y_dev, y_dev_pred)

In [22]:
train_cm, train_acc

(array([[225,  43],
        [ 22, 193]]),
 0.865424430641822)

In [23]:
dev_cm, dev_acc

(array([[99, 16],
        [10, 82]]),
 0.8743961352657005)