# Import Packages

In [72]:
from configs import get_config
from data_loader import get_loader

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.autograd import Variable

from tqdm import tqdm

# Set Configuration

In [77]:
config = get_config(
    parse=False,
    vocab_size=20000,
    hidden_size=100,
    n_channel_per_window=2,
    label_size=2,
    dropout=0.5)

In [78]:
config

Configurations
{'batch_size': 100,
 'data_dir': PosixPath('/Users/jmin/workspace/fastcampus_chatbot/Day_02/CNN/datasets'),
 'dropout': 0.5,
 'epochs': 20,
 'hidden_size': 100,
 'label_size': 2,
 'log_every_epoch': 1,
 'loss_fn': <class 'torch.nn.modules.loss.CrossEntropyLoss'>,
 'lr': 0.001,
 'n_channel_per_window': 2,
 'optimizer': <class 'torch.optim.sgd.SGD'>,
 'save_dir': PosixPath('/Users/jmin/workspace/fastcampus_chatbot/Day_02/CNN/log'),
 'save_every_epoch': 1,
 'vocab_size': 20000}

# Load training data loader

In [80]:
train_loader = get_loader(batch_size=20, max_size=config.vocab_size, is_train=True, data_dir='./datasets/')

Building Vocabulary 



In [81]:
batch = next(iter(train_loader))
batch

<torchtext.data.batch.Batch at 0x1284279b0>

In [84]:
# [max_seq_len, batch_size]
batch.text

Variable containing:
  3249    127   2623  ...     524    213   3657
    20     12    159  ...      12     13      9
     6     27    152  ...     376    785    173
        ...            ⋱           ...         
     1      1      1  ...       1      1      1
     1      1      1  ...       1      1      1
     1      1      1  ...       1      1      1
[torch.LongTensor of size 71x20]

In [38]:
# [batch_size]
batch.label

Variable containing:
 0
 1
 1
 0
 0
 1
 0
 1
 0
 1
 0
 0
 1
 1
 1
 0
 1
 0
 1
 0
[torch.LongTensor of size 20]

## Model

<img src="../images/cnn_text_classification.png", width=600, height=60>

In [35]:
class CNN(nn.Module):
    def __init__(self, config):
        super(CNN, self).__init__()
        self.config = config
        
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        
        self.conv = nn.ModuleList([
            nn.Conv2d(
                in_channels=1,
                out_channels=config.n_channel_per_window,
                kernel_size=(3, config.hidden_size)),
            
            nn.Conv2d(
                in_channels=1,
                out_channels=config.n_channel_per_window,
                kernel_size=(4, config.hidden_size)),

            nn.Conv2d(
                in_channels=1,
                out_channels=config.n_channel_per_window,
                kernel_size=(5, config.hidden_size))
        ])
        
        n_total_channels = len(self.conv) * config.n_channel_per_window
        
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(n_total_channels, config.label_size)
        
    def forward(self, x):
        """
        Args:
            x: [batch_size, max_seq_len]
        Return:
            logit: [batch_size, label_size]
        """
        
        # [batch_size, max_seq_len, hidden_size]
        x = self.embedding(x)
        
        # [batch_size, 1, max_seq_len, hidden_size]
        x = x.unsqueeze(1)
        
        # Apply Convolution filter followed by Max-pool
        out_list = []
        for conv in self.conv:
            
            ########## Convolution #########
            
            # [batch_size, n_kernels, _, 1]
            x_ = F.relu(conv(x))
            
            # [batch_size, n_kernels, _]
            x_ = x_.squeeze(3)
            
            ########## Max-pool #########
            
            # [batch_size, n_kernels, 1]
            x_ = F.max_pool1d(x_, x_.size(2))
            
            # [batch_size, n_kernels]
            x_ = x_.squeeze(2)
            
            out_list.append(x_)
        
        # [batch_size, 3 x n_kernels]
        out = torch.cat(out_list, 1)
        
        ######## Dropout ########
        out = self.dropout(out)
        
        # [batch_size, label_size]
        logit = self.fc(out)
        
        return logit

In [85]:
model = CNN(config)

In [86]:
model

CNN (
  (embedding): Embedding(20000, 100)
  (conv): ModuleList (
    (0): Conv2d(1, 2, kernel_size=(3, 100), stride=(1, 1))
    (1): Conv2d(1, 2, kernel_size=(4, 100), stride=(1, 1))
    (2): Conv2d(1, 2, kernel_size=(5, 100), stride=(1, 1))
  )
  (dropout): Dropout (p = 0.5)
  (fc): Linear (6 -> 2)
)

# Build loss function

In [87]:
loss_fn = config.loss_fn()

loss_fn

CrossEntropyLoss (
)

# Build Optimizer 

In [88]:
optimizer = config.optimizer(model.parameters(), config.lr)
optimizer

<torch.optim.sgd.SGD at 0x13cb046d8>

In [89]:
for epoch in range(2): # n_epochs
    print(f'Epoch: {epoch}')
    for batch_i, batch in enumerate(tqdm(train_loader)):
        # text: [max_seq_len, batch_size]
        # label: [batch_size]
        text, label = batch.text, batch.label

        # [batch_size, max_seq_len]
        text.data.t_()
        
        # [batch_size, 2]
        logit = model(text)
        
        # Calculate loss
        batch_loss = loss_fn(logit, label)
        batch_loss.backward()
        optimizer.step()
        
        if (batch_i + 1) % 50 == 0:
            tqdm.write(f'batch loss: {batch_loss.data}')


  0%|          | 0/7302 [00:00<?, ?it/s]

Epoch: 0


[A
  0%|          | 1/7302 [00:00<42:00,  2.90it/s][A
  0%|          | 4/7302 [00:00<31:01,  3.92it/s][A
  0%|          | 7/7302 [00:00<23:15,  5.23it/s][A
  0%|          | 10/7302 [00:00<17:45,  6.84it/s][A
  0%|          | 13/7302 [00:00<13:39,  8.89it/s][A
  0%|          | 17/7302 [00:00<10:40, 11.37it/s][A
  0%|          | 20/7302 [00:01<08:41, 13.96it/s][A
  0%|          | 24/7302 [00:01<07:20, 16.53it/s][A
  0%|          | 27/7302 [00:01<06:42, 18.09it/s][A
  0%|          | 30/7302 [00:01<06:20, 19.11it/s][A
  0%|          | 33/7302 [00:01<05:50, 20.72it/s][A
  1%|          | 37/7302 [00:01<05:22, 22.53it/s][A
  1%|          | 40/7302 [00:01<05:40, 21.31it/s][A
  1%|          | 43/7302 [00:01<05:17, 22.84it/s][A
  1%|          | 46/7302 [00:02<04:55, 24.52it/s][A
  1%|          | 49/7302 [00:02<04:43, 25.57it/s][A
          
 11%|█▏        | 839/7302 [4:02:08<31:05:18, 17.32s/it]
  1%|          | 52/7302 [00:02<05:09, 23.42it/s][A

batch loss: 
 0.6101
[torch.FloatTensor of size 1]




  1%|          | 55/7302 [00:02<05:22, 22.49it/s][A
  1%|          | 58/7302 [00:02<05:12, 23.21it/s][A
  1%|          | 61/7302 [00:02<05:24, 22.30it/s][A
  1%|          | 64/7302 [00:02<05:22, 22.47it/s][A
  1%|          | 67/7302 [00:02<05:06, 23.64it/s][A
  1%|          | 70/7302 [00:03<05:01, 23.99it/s][A
  1%|          | 73/7302 [00:03<05:10, 23.31it/s][A
  1%|          | 76/7302 [00:03<05:11, 23.18it/s][A
  1%|          | 79/7302 [00:03<05:01, 23.97it/s][A
  1%|          | 82/7302 [00:03<05:26, 22.09it/s][A
  1%|          | 85/7302 [00:03<05:08, 23.36it/s][A
  1%|          | 88/7302 [00:03<04:57, 24.28it/s][A
  1%|▏         | 92/7302 [00:04<04:40, 25.71it/s][A
  1%|▏         | 95/7302 [00:04<04:37, 25.97it/s][A
  1%|▏         | 98/7302 [00:04<04:32, 26.45it/s][A
          
 11%|█▏        | 839/7302 [4:02:10<31:05:34, 17.32s/it]
  1%|▏         | 101/7302 [00:04<04:53, 24.55it/s][A
  1%|▏         | 104/7302 [00:04<04:58, 24.08it/s][A

batch loss: 
 0.7106
[torch.FloatTensor of size 1]




  1%|▏         | 107/7302 [00:04<04:54, 24.42it/s][A
  2%|▏         | 110/7302 [00:04<04:40, 25.65it/s][A
  2%|▏         | 113/7302 [00:04<04:55, 24.35it/s][A
  2%|▏         | 116/7302 [00:05<05:17, 22.63it/s][A
  2%|▏         | 120/7302 [00:05<04:53, 24.46it/s][A
  2%|▏         | 123/7302 [00:05<04:52, 24.51it/s][A
  2%|▏         | 126/7302 [00:05<04:42, 25.38it/s][A
  2%|▏         | 129/7302 [00:05<04:58, 24.03it/s][A
  2%|▏         | 132/7302 [00:05<05:01, 23.79it/s][A
  2%|▏         | 135/7302 [00:05<04:55, 24.25it/s][A
  2%|▏         | 139/7302 [00:05<04:36, 25.94it/s][A
  2%|▏         | 142/7302 [00:06<04:25, 26.94it/s][A
  2%|▏         | 145/7302 [00:06<04:42, 25.35it/s][A
  2%|▏         | 148/7302 [00:06<04:50, 24.61it/s][A
          
 11%|█▏        | 839/7302 [4:02:12<31:05:50, 17.32s/it]
  2%|▏         | 151/7302 [00:06<04:55, 24.21it/s][A
  2%|▏         | 154/7302 [00:06<04:49, 24.66it/s][A

batch loss: 
 0.7033
[torch.FloatTensor of size 1]




  2%|▏         | 157/7302 [00:06<05:11, 22.95it/s][A
  2%|▏         | 160/7302 [00:06<04:58, 23.95it/s][A
  2%|▏         | 163/7302 [00:06<04:44, 25.08it/s][A
  2%|▏         | 166/7302 [00:07<04:41, 25.36it/s][A
  2%|▏         | 169/7302 [00:07<04:37, 25.75it/s][A
  2%|▏         | 172/7302 [00:07<04:39, 25.49it/s][A
  2%|▏         | 175/7302 [00:07<04:54, 24.16it/s][A
  2%|▏         | 178/7302 [00:07<05:04, 23.38it/s][A
  2%|▏         | 181/7302 [00:07<05:03, 23.46it/s][A
  3%|▎         | 184/7302 [00:07<04:44, 25.06it/s][A
  3%|▎         | 187/7302 [00:07<04:33, 25.97it/s][A
  3%|▎         | 190/7302 [00:07<04:23, 26.99it/s][A
  3%|▎         | 193/7302 [00:08<04:16, 27.76it/s][A
  3%|▎         | 196/7302 [00:08<04:30, 26.25it/s][A
  3%|▎         | 199/7302 [00:08<05:01, 23.55it/s][A
          
 11%|█▏        | 839/7302 [4:02:14<31:06:05, 17.32s/it]
  3%|▎         | 202/7302 [00:08<04:54, 24.09it/s][A
  3%|▎         | 205/7302 [00:08<04:38, 25.47it/s][A

batch loss: 
 0.7690
[torch.FloatTensor of size 1]




  3%|▎         | 208/7302 [00:08<04:59, 23.71it/s][A
  3%|▎         | 211/7302 [00:08<05:08, 23.00it/s][A
  3%|▎         | 254/7302 [00:10<04:15, 27.57it/s]

batch loss: 
 0.8075
[torch.FloatTensor of size 1]



  4%|▍         | 305/7302 [00:12<04:33, 25.61it/s]

batch loss: 
 0.7300
[torch.FloatTensor of size 1]



  5%|▍         | 354/7302 [00:14<04:23, 26.35it/s]

batch loss: 
 0.6384
[torch.FloatTensor of size 1]



  6%|▌         | 404/7302 [00:16<04:23, 26.17it/s]

batch loss: 
 0.9197
[torch.FloatTensor of size 1]



  6%|▌         | 452/7302 [00:17<04:40, 24.42it/s]

batch loss: 
 0.7895
[torch.FloatTensor of size 1]



  7%|▋         | 503/7302 [00:19<04:03, 27.94it/s]

batch loss: 
 0.7296
[torch.FloatTensor of size 1]



  8%|▊         | 553/7302 [00:21<03:47, 29.65it/s]

batch loss: 
 0.7809
[torch.FloatTensor of size 1]



  8%|▊         | 604/7302 [00:23<04:27, 25.04it/s]

batch loss: 
 0.8771
[torch.FloatTensor of size 1]



  9%|▉         | 640/7302 [00:24<03:53, 28.48it/s]

KeyboardInterrupt: 

            9%|▉         | 640/7302 [00:38<06:44, 16.46it/s]