In [23]:
import json
import pandas as pd
import numpy as np
from sklearn.model_selection import \
    train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, f1_score, confusion_matrix
from sklearn.feature_extraction.text import CountVectorizer
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler, Dataset
import random

In [24]:
# Load dataset
def load_dataset(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data

# Convert sequence of integers to string
def int_list_to_str(int_list):
    return ' '.join(map(str, int_list))

# Load data
set1_human = load_dataset("/kaggle/input/sml-data/data/set1_human.json")
set1_machine = load_dataset("/kaggle/input/sml-data/data/set1_machine.json")
set2_human = load_dataset("/kaggle/input/sml-data/data/set2_human.json")
set2_machine = load_dataset("/kaggle/input/sml-data/data/set2_machine.json")
set_test = load_dataset('/kaggle/input/sml-data/data/test.json')

print(type(set_test))


for d in set1_human:
    d['label'] = 1.0
for d in set1_machine:
    d['label'] = 0.0
for d in set2_human:
    d['label'] = 1.0
for d in set2_machine:
    d['label'] = 0.0

human_labels = np.ones(len(set1_human) + len(set2_human))
machine_labels = np.ones(len(set1_machine) + len(set2_machine))

human_labels2 = np.ones(len(set2_human))
machine_labels2 = np.ones(len(set2_machine))

data = set1_human + set2_human + set1_machine + set2_machine
labels = np.concatenate([human_labels, machine_labels])


data2 = set2_machine + set2_human
labels2 = np.concatenate([human_labels2, machine_labels2])

# # mock test 
set1_human = pd.DataFrame(set1_human)
set2_human = pd.DataFrame(set2_human)
set1_machine = pd.DataFrame(set1_machine)
set2_machine = pd.DataFrame(set2_machine)

# Train/validation split
train_data1, val_data1 = train_test_split(data, test_size=0.2, random_state=42, stratify=labels)
train_data2, val_data2 = train_test_split(data2, test_size=0.2, random_state=42, stratify=labels2)

print(len(train_data1))
print(len(val_data1))

<class 'list'>

Mock test sets
101267
25317


In [25]:
max_len = 256

class PaddingDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['txt'][:max_len]
        text_len = len(text)
        text = torch.tensor(text)
        if text_len < max_len:
            text = torch.nn.functional.pad(text, (0, max_len - text_len), "constant", 0)
        label = torch.tensor(self.data[idx]['label'])
        return text, label
    
class TestDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['txt'][:max_len]
        text_len = len(text)
        text = torch.tensor(text)
        if text_len < max_len:
            text = torch.nn.functional.pad(text, (0, max_len - text_len), "constant", 0)
#         label = torch.tensor(self.data[idx]['label'])
        return text

In [26]:
batch_size = 128
train1 = PaddingDataset(train_data1)
train2 = PaddingDataset(train_data2)
val1 = PaddingDataset(val_data1)
val2 = PaddingDataset(val_data2)
test1 = TestDataset(set_test[:600])
test2 = TestDataset(set_test[600:])
mock_test = PaddingDataset(mock_test_data)


# create sampler
class_weights_1 = [34.85,1]
sample_weights_1 = [0] * len(train1)
for i, (data, label) in enumerate(train1):
    label = label.int()
    class_weight = class_weights_1[label]
    sample_weights_1[i] = class_weight
sampler1 = WeightedRandomSampler(sample_weights_1, num_samples=len(sample_weights_1),replacement=True)

class_weights_2 = [1,4]
sample_weights_2 = [0] * len(train2)
for i, (data, label) in enumerate(train2):
    label = label.int()
    class_weight = class_weights_2[label]
    sample_weights_2[i] = class_weight
sampler2 = WeightedRandomSampler(sample_weights_2, num_samples=len(sample_weights_2),replacement=True)

#
train_loader1 = DataLoader(train1, batch_size=batch_size,sampler=sampler1) # or use sampler
train_loader2 = DataLoader(train2, batch_size=batch_size,sampler=sampler2)
valid_loader1 = DataLoader(val1,batch_size=batch_size,shuffle=False)
valid_loader2 = DataLoader(val2, batch_size=batch_size, shuffle=False)
test_loader1 = DataLoader(test1,batch_size=1)
test_loader2 = DataLoader(test2, batch_size=1)
X_mock_test = DataLoader(mock_test,batch_size=1)



In [38]:


# vocab_size = 5000
input_size = 256
embedding_dim = 128

device = torch.device('cuda')
print(device)



class RNNmodel(nn.Module):
    def __init__(self, input_size,hidden_size,output_size,num_layers,embedding_dim,batch_size):
        super(RNNmodel, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.embedding_dim=embedding_dim
        self.batch_size = batch_size
        self.embedding = nn.Embedding(self.input_size*batch_size,embedding_dim)
        self.maxpool = nn.MaxPool1d(4)
        self.LSTM = nn.LSTM(embedding_dim,self.hidden_size,self.num_layers,batch_first=True,dropout=0.2)
        self.fc = nn.Linear(hidden_dim//4,output_dim)

    def forward(self, x):
        embedded = self.embedding(x)
        out, (hidden, cell) = self.LSTM(embedded)
        pooled = self.maxpool(out)
        out = self.fc(pooled[:, -1, :])
        # print(out.shape)
        out = torch.sigmoid(out)
        return out

batch_size = 128
num_epochs = 15


# Create RNN
hidden_dim = 256
layer_dim = 1
output_dim = 1

model = RNNmodel(
    input_size=input_size,
    hidden_size=hidden_dim,
    output_size=output_dim,
    embedding_dim=embedding_dim,
    batch_size=batch_size,
    num_layers=2
).to(device)


cuda


In [30]:
class EarlyStopping:
    def __init__(self, tolerance=2, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True

early_stopping = EarlyStopping(tolerance=10, min_delta=0.05)


train_loss = []
error = nn.BCEWithLogitsLoss()


from tqdm import tqdm
def do_validation(valid_loader,model):
    iteration = 0
    correct = 0
    total = 0
    loss_list = []
    acc_list = []
    pred_list = []
    true_list = []
    # Iterate through test dataset
    for e, l in valid_loader:
        train_loss = []
        # Forward propagation
        e=e.to(device).long()
        # print(e.shape)
        # print(l.shape)
        l=l.to(device)
        l = l.unsqueeze(1)
        logits_val = model(e)
        loss_val = error(logits_val, l)
        loss_list.append(loss_val.tolist())
        predicted_labels = (logits_val>0.5).float()
        pred_list += predicted_labels.int().tolist()
        true_list += l.int().tolist()
        acc_list.append(accuracy)
        train_loss.append(loss_val.tolist())
        iteration += 1
        if iteration % 100:
            loss_list.append(sum(train_loss)/len(train_loss))
            train_loss = []
    f1 = f1_score(pred_list, true_list)
    print('{}'.format(confusion_matrix(pred_list, true_list)))
    print('f1 score : {}'.format(f1))
    loss = sum(loss_list)/len(loss_list)
    print('validation loss: {}'.format(loss))
    return loss_list,f1


In [12]:
do_validation(valid_loader2,model)

[[21  0]
 [58 21]]
f1 score : 0.41999999999999993


([0.9662222266197205, 0.9662222266197205], 0.41999999999999993)

In [31]:
# train
# optimizer
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
error = nn.BCEWithLogitsLoss()

iteration_list = []
validation_loss1 = []
validation_loss2 = []
iteration = 0
loss_list = []
max_f1 = 0

for epoch in range(num_epochs):
    model.train()
    loop = tqdm(enumerate(train_loader1), total=len(train_loader1), leave=False)
    for i, (elem, labels) in loop:
        train_loss = []
        elem=elem.to(device)
        labels = Variable(torch.Tensor(labels)).to(device)
        # Clear gradient
        optimizer.zero_grad()
        # Forward
        logits = model(elem)
        labels = labels.unsqueeze(1).to(device)
        # Calculate cross entropy loss
        predicted_labels = (logits>0.5).float()
        loss = error(logits, labels)
        # Calculating gradient
        correct_train = (predicted_labels == labels).float().sum()
        accuracy = correct_train/len(predicted_labels)
        loss.backward()
        # Update parameters
        optimizer.step()
        loop.set_description(f"Epoch [{epoch}/{num_epochs}]")
        loop.set_postfix(loss=loss.data)
        train_loss.append(loss.tolist())
        iteration += 1
        if iteration % 100:
            loss_list.append(sum(train_loss)/len(train_loss))
            train_loss = []
    val_loss1, val1_f1 = do_validation(valid_loader1,model)
    validation_loss1 += val_loss1
    val_loss2, val2_f1 = do_validation(valid_loader2,model)
    validation_loss2 += val_loss2
    print('train loss: {}'.format(epoch_train_loss))
    early_stopping(epoch_train_loss, epoch_validate_loss1)
    early_stopping(epoch_train_loss, epoch_validate_loss2)
    if early_stopping.early_stop:
        print("We are at epoch:", epoch)
        break



                                                                                                     

[[  703  6509]
 [   89 18016]]
f1 score : 0.8452263664086325
validation loss: 0.4918071998071067
[[22  1]
 [57 20]]
f1 score : 0.40816326530612246
validation loss: 0.8566935658454895


                                                                                                     

[[  696  5584]
 [   96 18941]]
f1 score : 0.8696111289656122
validation loss: 0.41521965038927294
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9676960706710815


                                                                                                     

[[  698  5788]
 [   94 18737]]
f1 score : 0.8643325029984316
validation loss: 0.4161413065240353
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9578865766525269


                                                                                                     

[[  699  5602]
 [   93 18923]]
f1 score : 0.8692037390046163
validation loss: 0.4221421421328677
[[22  3]
 [57 18]]
f1 score : 0.375
validation loss: 0.8399122357368469


                                                                                                     

[[  707  5546]
 [   85 18979]]
f1 score : 0.8708160315675972
validation loss: 0.4094199056866803
[[29 10]
 [50 11]]
f1 score : 0.2682926829268293
validation loss: 0.821945309638977


                                                                                                     

[[  761  7001]
 [   31 17524]]
f1 score : 0.8328897338403042
validation loss: 0.4185828589185884
[[79 21]
 [ 0  0]]
f1 score : 0.0
validation loss: 0.7989631295204163


                                                                                                    

KeyboardInterrupt: 

In [34]:
model2 = RNNmodel(
    input_size=input_size,
    hidden_size=hidden_dim,
    output_size=output_dim,
    embedding_dim=embedding_dim,
    batch_size=128,
    num_layers=2
).to(device)


In [37]:
# train model 2 
loss_list2=[]
iteration = 0
validation_iter1 = 0
validation_iter2 = 0
accuracy_list = []
validation_loss1 = []
validation_loss2 = []
max_f1 = 0
optimizer2 = torch.optim.SGD(model2.parameters(), 1e-5, momentum=0.9)
error2 = nn.BCEWithLogitsLoss()
for epoch in range(1000):
    model2.train()
    loop2 = tqdm(enumerate(train_loader2), total=len(train_loader2),leave=False)
    
    for i, (elem2, labels2) in loop2:
        # elem = elem.reshape(batch_size,input_size,1)
        train_loss = []
        elem2=elem2.to(device)
        labels2 = Variable(torch.Tensor(labels2)).to(device)
        # Clear gradient
        optimizer2.zero_grad()
        # Forward
        logits = model2(elem2)
        labels2 = labels2.unsqueeze(1).to(device)
        predicted_labels2 = (logits>0.5).float()
        # Calculate cross entropy loss
        loss2 = error2(logits, labels2)
        # Calculating gradient
        loss2.backward()
        # Update parameters
        optimizer2.step()
        loop2.set_description(f"Epoch [{epoch}/{num_epochs}]")
        loop2.set_postfix(loss2=loss2.data)
        iteration += 1
        train_loss.append(loss2.tolist())
        if iteration % 100:
            loss_list2.append(sum(train_loss)/len(train_loss))
            train_loss = []
    val_loss1, val1_f1 = do_validation(valid_loader1,model2)
    validation_loss1 += val_loss1
    val_loss2, val2_f1 = do_validation(valid_loader2,model2)
    validation_loss2 += val_loss2
    if val1_f1 + val2_f1 > max_f1:
        max_f1 = val1_f1 + val2_f1
        torch.save(model2.state_dict(), './model_LSTM_maxpool')
    loss_list.append(sum(train_loss)/len(train_loss))
    if early_stopping.early_stop:
        print("We are at epoch:", epoch)
        break

                                                                                                  

[[  697  5625]
 [   95 18900]]
f1 score : 0.8685661764705883
validation loss: 0.4161389670040034
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9663030505180359


                                                                                                  

[[  697  5637]
 [   95 18888]]
f1 score : 0.8682541141858969
validation loss: 0.41615882998780357
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965644359588623


                                                                                                  

[[  696  5629]
 [   96 18896]]
f1 score : 0.8684422179837764
validation loss: 0.4161714555341986
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9661828279495239


                                                                                                  

[[  696  5634]
 [   96 18891]]
f1 score : 0.8683121897407611
validation loss: 0.4161562948287288
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9661934971809387


                                                                                                  

[[  696  5624]
 [   96 18901]]
f1 score : 0.8685722163503516
validation loss: 0.4161298268957983
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654159545898438


                                                                                                  

[[  696  5633]
 [   96 18892]]
f1 score : 0.8683381977799739
validation loss: 0.41615323990206177
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659562706947327


                                                                                                  

[[  697  5629]
 [   95 18896]]
f1 score : 0.8684621748322457
validation loss: 0.4161486675467672
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9662315845489502


                                                                                                  

[[  698  5624]
 [   94 18901]]
f1 score : 0.8686121323529412
validation loss: 0.4161547085906886
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.966044545173645


                                                                                                  

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.4161580733106106
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658309817314148


                                                                                                  

[[  698  5635]
 [   94 18890]]
f1 score : 0.8683260934519296
validation loss: 0.4161556784110733
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656754732131958


                                                                                                   

[[  698  5621]
 [   94 18904]]
f1 score : 0.8686901178687131
validation loss: 0.4161393366282499
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.96598881483078


                                                                                                   

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.41615381059767326
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9661348462104797


                                                                                                   

[[  697  5631]
 [   95 18894]]
f1 score : 0.8684101668428552
validation loss: 0.4161575103107887
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660904407501221


                                                                                                   

[[  696  5631]
 [   96 18894]]
f1 score : 0.8683902102723199
validation loss: 0.416162075573885
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9655219912528992


                                                                                                   

[[  697  5634]
 [   95 18891]]
f1 score : 0.8683321458941416
validation loss: 0.4161824335780325
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658991098403931


                                                                                                   

[[  697  5631]
 [   95 18894]]
f1 score : 0.8684101668428552
validation loss: 0.41616591320762153
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660435318946838


                                                                                                   

[[  697  5629]
 [   95 18896]]
f1 score : 0.8684621748322457
validation loss: 0.4161829030966457
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660089015960693


                                                                                                   

[[  698  5636]
 [   94 18889]]
f1 score : 0.8683000827434035
validation loss: 0.4161929688121699
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660292863845825


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.41619257851492003
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657336473464966


                                                                                                   

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.41617155346689344
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965979278087616


                                                                                                   

[[  697  5634]
 [   95 18891]]
f1 score : 0.8683321458941416
validation loss: 0.41617479860028134
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.966235339641571


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.4161796208423904
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9665476679801941


                                                                                                   

[[  696  5633]
 [   96 18892]]
f1 score : 0.8683381977799739
validation loss: 0.4161689554588704
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9662594199180603


                                                                                                   

[[  696  5636]
 [   96 18889]]
f1 score : 0.8682601700758446
validation loss: 0.416176476735103
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654934406280518


                                                                                                   

[[  697  5630]
 [   95 18895]]
f1 score : 0.8684361714351373
validation loss: 0.4161838837062256
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658135771751404


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41618473922150045
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9661620855331421


                                                                                                   

[[  698  5630]
 [   94 18895]]
f1 score : 0.8684561290619112
validation loss: 0.41617839245856564
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658628702163696


                                                                                                   

[[  698  5630]
 [   94 18895]]
f1 score : 0.8684561290619112
validation loss: 0.41618491305580624
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658966064453125


                                                                                                   

[[  697  5631]
 [   95 18894]]
f1 score : 0.8684101668428552
validation loss: 0.4161852682693095
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657514691352844


                                                                                                   

[[  698  5633]
 [   94 18892]]
f1 score : 0.8683781112822044
validation loss: 0.4161835942087294
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.966505229473114


                                                                                                   

[[  697  5631]
 [   95 18894]]
f1 score : 0.8684101668428552
validation loss: 0.41620073009140884
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660452008247375


                                                                                                   

[[  696  5631]
 [   96 18894]]
f1 score : 0.8683902102723199
validation loss: 0.41618768323825883
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660204648971558


                                                                                                   

[[  697  5635]
 [   95 18890]]
f1 score : 0.8683061365203402
validation loss: 0.4161874799788753
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654699563980103


                                                                                                   

[[  696  5630]
 [   96 18895]]
f1 score : 0.8684162147256183
validation loss: 0.4161899768099
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657705426216125


                                                                                                   

[[  697  5632]
 [   95 18893]]
f1 score : 0.8683841610553169
validation loss: 0.41620064958741393
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660520553588867


                                                                                                   

[[  697  5634]
 [   95 18891]]
f1 score : 0.8683321458941416
validation loss: 0.41621597718588915
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660416841506958


                                                                                                   

[[  697  5635]
 [   95 18890]]
f1 score : 0.8683061365203402
validation loss: 0.41619862597199936
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660032987594604


                                                                                                   

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.4161952282809004
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9661757946014404


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41619441275355185
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657267332077026


                                                                                                   

[[  696  5635]
 [   96 18890]]
f1 score : 0.8682861805060789
validation loss: 0.4162090629716463
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.966101884841919


                                                                                                   

[[  697  5621]
 [   95 18904]]
f1 score : 0.8686701589927396
validation loss: 0.41620171070098877
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657902121543884


                                                                                                   

[[  696  5631]
 [   96 18894]]
f1 score : 0.8683902102723199
validation loss: 0.4162205846249303
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658797979354858


                                                                                                   

[[  697  5632]
 [   95 18893]]
f1 score : 0.8683841610553169
validation loss: 0.4161867131160784
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9653084874153137


                                                                                                   

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.41620961918106564
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659242033958435


                                                                                                   

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.4161988652205165
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9662089943885803


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.4161946287638024
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660885334014893


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.4162110036687006
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659618735313416


                                                                                                   

[[  697  5638]
 [   95 18887]]
f1 score : 0.8682281012250903
validation loss: 0.4162077227725258
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656348824501038


                                                                                                   

[[  697  5625]
 [   95 18900]]
f1 score : 0.8685661764705883
validation loss: 0.416207849677605
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654204845428467


                                                                                                   

[[  697  5629]
 [   95 18896]]
f1 score : 0.8684621748322457
validation loss: 0.4162110150614871
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657947421073914


                                                                                                   

[[  696  5634]
 [   96 18891]]
f1 score : 0.8683121897407611
validation loss: 0.4162170293210428
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659988284111023


                                                                                                   

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.4162219811089431
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654587507247925


                                                                                                   

[[  697  5635]
 [   95 18890]]
f1 score : 0.8683061365203402
validation loss: 0.4162231597719313
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.96587735414505


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.4162091585654247
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657105803489685


                                                                                                   

[[  697  5634]
 [   95 18891]]
f1 score : 0.8683321458941416
validation loss: 0.41621875521502916
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656838774681091


                                                                                                   

[[  697  5633]
 [   95 18892]]
f1 score : 0.8683581540724398
validation loss: 0.41623268897020366
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660121202468872


                                                                                                   

[[  697  5635]
 [   95 18890]]
f1 score : 0.8683061365203402
validation loss: 0.41622049069102807
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965868353843689


                                                                                                   

[[  697  5642]
 [   95 18883]]
f1 score : 0.8681240374227065
validation loss: 0.4162242794338661
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657104015350342


                                                                                                   

[[  697  5630]
 [   95 18895]]
f1 score : 0.8684361714351373
validation loss: 0.4162176625638068
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658961892127991


                                                                                                   

[[  697  5634]
 [   95 18891]]
f1 score : 0.8683321458941416
validation loss: 0.4162333655960952
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9662827849388123


                                                                                                   

[[  698  5635]
 [   94 18890]]
f1 score : 0.8683260934519296
validation loss: 0.41621773227860653
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9661544561386108


                                                                                                   

[[  697  5635]
 [   95 18890]]
f1 score : 0.8683061365203402
validation loss: 0.4162441769732705
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657325744628906


                                                                                                   

[[  697  5642]
 [   95 18883]]
f1 score : 0.8681240374227065
validation loss: 0.4162389533429206
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658129811286926


                                                                                                   

[[  697  5637]
 [   95 18888]]
f1 score : 0.8682541141858969
validation loss: 0.41622382500503635
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9664207100868225


                                                                                                   

[[  697  5628]
 [   95 18897]]
f1 score : 0.8684881770342624
validation loss: 0.4162273519401309
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965509831905365


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.41624416052540647
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658819437026978


                                                                                                   

[[  696  5642]
 [   96 18883]]
f1 score : 0.8681040823832291
validation loss: 0.4162463578996779
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657185077667236


                                                                                                   

[[  697  5632]
 [   95 18893]]
f1 score : 0.8683841610553169
validation loss: 0.4162451489062249
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965498149394989


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.416241954927203
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660825729370117


                                                                                                   

[[  696  5631]
 [   96 18894]]
f1 score : 0.8683902102723199
validation loss: 0.41624569259112393
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658787250518799


                                                                                                   

[[  698  5636]
 [   94 18889]]
f1 score : 0.8683000827434035
validation loss: 0.4162428016904034
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658696055412292


                                                                                                   

[[  696  5635]
 [   96 18890]]
f1 score : 0.8682861805060789
validation loss: 0.41626209835462935
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965968906879425


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.4162491145013254
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660808444023132


                                                                                                   

[[  697  5632]
 [   95 18893]]
f1 score : 0.8683841610553169
validation loss: 0.41626042014435877
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660555720329285


                                                                                                   

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41623876517332054
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660572409629822


                                                                                                   

[[  696  5637]
 [   96 18888]]
f1 score : 0.868234158449976
validation loss: 0.41625624110427084
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659215807914734


                                                                                                   

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41625886152062236
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659579396247864


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.41627084092248845
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9655364751815796


                                                                                                   

[[  697  5635]
 [   95 18890]]
f1 score : 0.8683061365203402
validation loss: 0.4162388413767271
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660012722015381


                                                                                                   

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41626159186604655
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660966396331787


                                                                                                   

[[  697  5636]
 [   95 18889]]
f1 score : 0.8682801259509526
validation loss: 0.41624976260752616
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659214615821838


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41627327052852775
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657689929008484


                                                                                                   

[[  697  5630]
 [   95 18895]]
f1 score : 0.8684361714351373
validation loss: 0.41626172020465513
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9662091135978699


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41627056402496143
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965846836566925


                                                                                                   

[[  697  5630]
 [   95 18895]]
f1 score : 0.8684361714351373
validation loss: 0.4162658863429782
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656041264533997


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41626190384732015
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9655399918556213


                                                                                                   

[[  697  5643]
 [   95 18882]]
f1 score : 0.8680980184819089
validation loss: 0.41627579867085324
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658318758010864


                                                                                                   

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41626202705540233
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.966029167175293


                                                                                                   

[[  698  5634]
 [   94 18891]]
f1 score : 0.8683521029648356
validation loss: 0.41628029897243163
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654677510261536


                                                                                                   

[[  696  5638]
 [   96 18887]]
f1 score : 0.8682081456283902
validation loss: 0.4162736877610412
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9653207063674927


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41627886778191675
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658846855163574


                                                                                                   

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.4162665317330179
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9663143754005432


                                                                                                   

[[  696  5634]
 [   96 18891]]
f1 score : 0.8683121897407611
validation loss: 0.41626127279257474
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656310677528381


                                                                                                   

[[  697  5642]
 [   95 18883]]
f1 score : 0.8681240374227065
validation loss: 0.41628816980349864
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657647609710693


                                                                                                   

[[  696  5640]
 [   96 18885]]
f1 score : 0.8681561163977382
validation loss: 0.41629710091820243
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657955765724182


                                                                                                   

[[  697  5643]
 [   95 18882]]
f1 score : 0.8680980184819089
validation loss: 0.4162776861009718
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656212329864502


                                                                                                   

[[  697  5639]
 [   95 18886]]
f1 score : 0.8682020870684504
validation loss: 0.41627680863006206
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.965447187423706


                                                                                                   

[[  697  5642]
 [   95 18883]]
f1 score : 0.8681240374227065
validation loss: 0.4162813041783586
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659435153007507


                                                                                                   

[[  696  5639]
 [   96 18886]]
f1 score : 0.8681821316110052
validation loss: 0.4162890882431706
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659533500671387


                                                                                                   

[[  696  5635]
 [   96 18890]]
f1 score : 0.8682861805060789
validation loss: 0.41630059556116034
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656620621681213


                                                                                                    

[[  697  5637]
 [   95 18888]]
f1 score : 0.8682541141858969
validation loss: 0.41627673650089697
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657546877861023


                                                                                                    

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41629309956031507
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656456708908081


                                                                                                    

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41629720209520077
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658536911010742


                                                                                                    

[[  698  5643]
 [   94 18882]]
f1 score : 0.8681179742994414
validation loss: 0.4162843016129506
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9649637937545776


                                                                                                    

[[  696  5645]
 [   96 18880]]
f1 score : 0.8680260223902898
validation loss: 0.41632291335093824
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660050272941589


                                                                                                    

[[  697  5642]
 [   95 18883]]
f1 score : 0.8681240374227065
validation loss: 0.41630462483514713
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659325480461121


                                                                                                    

[[  697  5638]
 [   95 18887]]
f1 score : 0.8682281012250903
validation loss: 0.41630335940590385
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658713936805725


                                                                                                    

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.4162919928001452
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656704664230347


                                                                                                    

[[  696  5638]
 [   96 18887]]
f1 score : 0.8682081456283902
validation loss: 0.4162816852708406
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660078287124634


                                                                                                    

[[  697  5638]
 [   95 18887]]
f1 score : 0.8682281012250903
validation loss: 0.4163051578062999
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658178687095642


                                                                                                    

[[  698  5640]
 [   94 18885]]
f1 score : 0.8681960279514527
validation loss: 0.4162836051439937
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9654679894447327


                                                                                                    

[[  698  5637]
 [   94 18888]]
f1 score : 0.8682740708391753
validation loss: 0.4163055747370177
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657890200614929


                                                                                                    

[[  697  5638]
 [   95 18887]]
f1 score : 0.8682281012250903
validation loss: 0.41630544919001905
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9652913212776184


                                                                                                    

[[  696  5641]
 [   96 18884]]
f1 score : 0.868130099988507
validation loss: 0.4163056644457805
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9655189514160156


                                                                                                    

[[  697  5642]
 [   95 18883]]
f1 score : 0.8681240374227065
validation loss: 0.41631697340856627
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657037854194641


                                                                                                    

[[  697  5637]
 [   95 18888]]
f1 score : 0.8682541141858969
validation loss: 0.4163062299354167
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9658631086349487


                                                                                                    

[[  697  5646]
 [   95 18879]]
f1 score : 0.8680199544817123
validation loss: 0.4163055965417548
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9659194946289062


                                                                                                    

[[  698  5644]
 [   94 18881]]
f1 score : 0.8680919540229886
validation loss: 0.41631191350236724
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9653860926628113


                                                                                                    

[[  696  5644]
 [   96 18881]]
f1 score : 0.868052043584203
validation loss: 0.41631411917601957
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9655288457870483


                                                                                                    

[[  697  5643]
 [   95 18882]]
f1 score : 0.8680980184819089
validation loss: 0.4163193577452551
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657846689224243


                                                                                                    

[[  698  5648]
 [   94 18877]]
f1 score : 0.8679878609527313
validation loss: 0.4163338627996324
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9656216502189636


                                                                                                    

[[  698  5638]
 [   94 18887]]
f1 score : 0.8682480577391626
validation loss: 0.416328463297856
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9655342102050781


                                                                                                    

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41632240573062174
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657105803489685


                                                                                                    

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41632977708985536
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9657288193702698


                                                                                                    

[[  697  5640]
 [   95 18885]]
f1 score : 0.8681760717158947
validation loss: 0.41630953112735025
[[21  0]
 [58 21]]
f1 score : 0.41999999999999993
validation loss: 0.9660894274711609


                                                                                                    

KeyboardInterrupt: 

In [47]:
torch.save(model.state_dict(), './model_LSTM_maxpool')

In [40]:
output = []
for elem in test_loader1:
    elem=torch.Tensor(elem)
    elem = elem.to(device)
    predict = model(elem)
    if predict > 0.5:
        predict = 1
    else:
        predict = 0
    # predict = predict>=0.5,1,0
    output.append(predict)
    
for elem in test_loader2:
    elem=torch.Tensor(elem)
    elem = elem.to(device)
    predict = model2(elem)
    if predict > 0.5:
        predict = 1
    else:
        predict = 0
    # predict = predict>=0.5,1,0
    output.append(predict)    

result = pd.DataFrame(output)
display(result)

Unnamed: 0,0
0,0
1,0
2,0
3,0
4,0
...,...
995,1
996,1
997,0
998,0


In [41]:
pd.DataFrame.to_csv(result,'./LSTM_maxpool_2models.csv')

In [20]:
model.load_state_dict(torch.load('/kaggle/working/model_LSTM_maxpool'))

NameError: name 'X_train1' is not defined

In [None]:
result = model.predict(test)
# display(result)
result = pd.DataFrame(result, columns=['Predicted'],index=range(1000))


In [None]:
from matplotlib import pyplot as plt
# visualization loss
# print(loss_list)
# iteration_list = list(range(len(loss_list2)))
plt.plot(list(range(len(loss_list)))[::100],loss_list[::100])
plt.xlabel("Number of iteration")
plt.ylabel("Loss")
plt.title("LSTM: Loss vs Number of iteration")
plt.show()

# print(loss_list)
plt.plot(list(range(len(validation_loss1)))[::100],validation_loss1[::100])
plt.xlabel("Number of iteration")
plt.ylabel("Validation set1 Loss")
plt.title("LSTM: Loss vs Number of iteration")
plt.show()

# print(validation_loss2)
plt.plot(list(range(len(validation_loss2))),validation_loss2)
plt.xlabel("Number of iteration")
plt.ylabel("Validation set2 Loss")
plt.title("LSTM: Loss vs Number of iteration")
plt.show()


In [None]:
pd.DataFrame.to_csv(result,'./LR.csv')