# <center>PROJECT SANDBOX</center>

## Documentation
The aim of this notebook is to provide a simple sandbox to test different NN architectures for the project. , here is a doc about the functions imported from `scripts` folder : 

- **`prepare_dataset(device,ratio=0.5,shuffle_ctx=False)`** :
    - **Input**:
        - device : a torch.device object
        - ratio : a float ratio between 0 and 1 that determines the average proportion of modern english verses in the data loader
        - shuffle_ctx : if `True`, shuffle the contexts within a Batch so that half of the `x_1` elements has a wrong context `ctx_1`. Useful to train the context recognizer model.
    - **Return** :
        - a torch Dataset | class : Shakespeare inherited from torch.utils.data.Dataset
        - a python word dictionary (aka tokenizer) | class : dict
    - **Tensors returned when loaded in the dataloader**:
        - x_1 : input verse (modern / shakespearian)
        - x_2 : output verse (modern / shakespearian)

        - ctx_1 = context of the input verse
        - ctx_2 = context of the output verse

        - len_x : length of the input verse
        - len_y : length of the output verse

        - len_ctx_x : length of the input verse context
        - len_ctx_y : length of the output verse context

        - label : label of the input verse (0 : modern, 1 : shakespearian)
        - label_ctx : label of the context (0 : wrong context, 1 : right context)
- **`string2code(string,dict)`** : 
    - **Input**:
        - string : a sentence
        - dict : a tokenizer
    - **Return** :
        - a torch Longtensor (sentence tokenized)
- **`code2string(torch.Longtensor,dict)`** : 
    - **Input**:
        - torch.Longtensor : a sentence tokenized
        - dict : a tokenizer
    - **Return** :
        - a string sentence

## Importing packages

In [1]:
from scripts.data_builders.prepare_dataset import prepare_dataset,string2code,code2string,assemble

import torch
import torchvision.datasets as datasets
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.nn import BCELoss
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import pickle
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device = ",device)

device =  cuda


## Preprocessing data

In [2]:
train_data, dict_words = prepare_dataset(device,ratio=0.5,shuffle_ctx=True) #check with shift+tab to look at the data structure
batch_size = 128
dict_token = {b:a for a,b in dict_words.items()} #dict for code2string

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True,collate_fn=train_data.collate)

Loading ...
- Shakespeare dataset length :  21079
- Corrupted samples (ignored) :  0


## Designing NN model

### Language Model 

In [3]:
dict_size = len(dict_words) #19089
d_embedding = 300 #cf. paper Y.Kim 2014 Convolutional Neural Networks for Sentence Classification


In [4]:
class CoherenceClassifier(torch.nn.Module):
    def __init__(self,dict_size=dict_size,d_embedding=300):
        super().__init__()
        self.embed_layer=torch.nn.Embedding(dict_size+1,d_embedding,padding_idx=dict_size)

        self.conv_1 = torch.nn.Conv1d(d_embedding,3,kernel_size = 3, stride = 1)
        self.max_pool = torch.nn.MaxPool1d(3,2)
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(3,1)
        # self.f=lambda x: torch.norm(x,dim=1)**2 (I am not sure it is necessary at all)
        self.f = lambda x:x
        self.sigmoid=torch.nn.Sigmoid()
    
    def forward(self,x,ctx):
        x = self.embed_layer(x)
        ctx = self.embed_layer(ctx)
        x = torch.cat((self.f(x),ctx),dim=1)
        x = self.conv_1( x.transpose(1,2))
        x = self.max_pool( x )
        x = self.relu( x )
        x = torch.max( x , 2 )[0]
        x = self.sigmoid(self.linear(x))
        return(x)

        
    

In [5]:
class CoherenceClassifier(torch.nn.Module):
    def __init__(self,dict_size=dict_size,d_embedding=300,d_hidden=100):
        super().__init__()
        self.d_hidden = d_hidden
        self.embedding = nn.Embedding(dict_size+1,d_embedding,padding_idx=dict_size)
        self.lstm = nn.LSTM(d_embedding,self.d_hidden,dropout=0.,num_layers=1,bidirectional=False)
        self.linear = torch.nn.Linear(self.d_hidden,1)
    
    def forward(self,x,len_x):
        x = self.embedding(x)
        x = pack_padded_sequence(x.permute(1,0,2),len_x,enforce_sorted=False)
        _,x = self.lstm(x)
        x = x[0].reshape(-1,self.d_hidden)
        x = torch.sigmoid( self.linear(x) ).reshape(-1)
        
        return x



## Running model

In [None]:
for x,y , ctx_x,ctx_y , len_x,len_y , len_ctx_x,len_ctx_y, label,label_ctx in train_loader:
    
    for i in range(x.shape[0]):
        print("\n- x :")
        print(code2string(x[i],dict_token))
        print("- context of x :")
        print(code2string(ctx_x[i],dict_token))
        print("- context label :",label_ctx[i].item())

        print("- len_ctx_x :")
        print(ctx_x)
#         ipdb.set_trace()
    break

In [6]:
epochs=100
model=CoherenceClassifier().to(device)
optimizer= optim.Adam(params=model.parameters(),lr=0.001)
loss_func=BCELoss()

In [7]:
n = len(train_data.x) // batch_size

for epoch in range(epochs):
    total_loss = 0
    
    for x,y , ctx_x,ctx_y , len_x,len_y , len_ctx_x,len_ctx_y, label,label_ctx in train_loader:
        optimizer.zero_grad()
        
        x = model.forward(x,len_x)
        loss = loss_func(x,label_ctx.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(epoch,"\t",round(total_loss/n,5))

0 	 0.69855
1 	 0.69757
2 	 0.69741
3 	 0.69733
4 	 0.69755
5 	 0.69745
6 	 0.6974
7 	 0.69739
8 	 0.69747
9 	 0.69746
10 	 0.69745
11 	 0.69739
12 	 0.69744
13 	 0.69742
14 	 0.69731
15 	 0.69744
16 	 0.69743
17 	 0.69742
18 	 0.69741
19 	 0.69742
20 	 0.69744
21 	 0.69741
22 	 0.69741
23 	 0.69741
24 	 0.69742
25 	 0.69743
26 	 0.69738
27 	 0.69746
28 	 0.69741
29 	 0.69741
30 	 0.6974
31 	 0.69742
32 	 0.69742
33 	 0.6974
34 	 0.6974
35 	 0.69739
36 	 0.69741
37 	 0.69742
38 	 0.69738
39 	 0.69738
40 	 0.69741
41 	 0.69742
42 	 0.69741
43 	 0.69737
44 	 0.69745
45 	 0.6974
46 	 0.69738
47 	 0.69746
48 	 0.69739
49 	 0.69743
50 	 0.69742
51 	 0.69742
52 	 0.69743
53 	 0.69744
54 	 0.69743
55 	 0.69736
56 	 0.69743
57 	 0.69745
58 	 0.69744
59 	 0.69743
60 	 0.69741
61 	 0.69743
62 	 0.69742
63 	 0.6974
64 	 0.69743
65 	 0.69742
66 	 0.69745
67 	 0.6974
68 	 0.69739
69 	 0.69742
70 	 0.69734
71 	 0.69747
72 	 0.69742
73 	 0.69742
74 	 0.69734
75 	 0.69745
76 	 0.69737
77 	 0.69746
78 