In [1]:
!pip install transformers



In [2]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel
import torch.nn as nn
from tqdm import tqdm, tqdm_notebook

In [3]:
U_TKN = '<usr>'
S_TKN = '<sys>'
BOS = '</s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'
PAD = '<pad>'

In [4]:
TOKENIZER = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2",
            bos_token=BOS, eos_token=EOS, unk_token='<unk>',
            pad_token=PAD, mask_token=MASK)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [5]:
class CharDataset(Dataset):
    def __init__(self, chats, max_len=32):
        self._data = chats
        self.first = True
        self.q_token = U_TKN
        self.a_token = S_TKN
        self.sent_token = SENT
        self.bos = BOS
        self.eos = EOS
        self.mask = MASK
        self.pad = PAD
        self.max_len = max_len
        self.tokenizer = TOKENIZER 

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        turn = self._data.iloc[idx]
        q = turn['Q']
        a = turn['A']
        #sentiment = str(turn['label'])
        q_toked = self.tokenizer.tokenize(self.q_token + q)   
        q_len = len(q_toked)
        a_toked = self.tokenizer.tokenize(self.a_token + a + self.eos)
        a_len = len(a_toked)
        if q_len + a_len > self.max_len:
            a_len = self.max_len - q_len
            if a_len <= 0:
                q_toked = q_toked[-(int(self.max_len/2)):]
                q_len = len(q_toked)
                a_len = self.max_len - q_len
                assert a_len > 0
            a_toked = a_toked[:a_len]
            a_len = len(a_toked)
            assert a_len == len(a_toked), f'{a_len} ==? {len(a_toked)}'
        # [mask, mask, ...., mask, ..., <bos>,..A.. <eos>, <pad>....]
        labels = [
            self.mask,
        ] * q_len + a_toked[1:]
        mask = [0] * q_len + [1] * a_len + [0] * (self.max_len - q_len - a_len)
        self.max_len
        labels_ids = self.tokenizer.convert_tokens_to_ids(labels)
        while len(labels_ids) < self.max_len:
            labels_ids += [self.tokenizer.pad_token_id]
        token_ids = self.tokenizer.convert_tokens_to_ids(q_toked + a_toked)
        while len(token_ids) < self.max_len:
            token_ids += [self.tokenizer.pad_token_id]
        return(token_ids, np.array(mask),
               labels_ids)

In [6]:
class KoGPT2Chat(nn.Module):
    def __init__(self):
        super(KoGPT2Chat, self).__init__()
        self.neg = -1e18
        self.kogpt2 = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
        self.loss_function = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs):
        # (batch, seq_len, hiddens)
        output = self.kogpt2(inputs, return_dict=True)
        return output.logits

In [7]:
model = KoGPT2Chat().cuda()

In [8]:
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [9]:
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5)



In [10]:
loss_fn = nn.CrossEntropyLoss()

In [11]:
data = pd.read_csv('/content/drive/MyDrive/자연어처리음성인식/Trash/concatData.csv')

In [12]:
train_set = CharDataset(data, max_len=32)

In [13]:
def _collate_fn(batch):
    data = [item[0] for item in batch]
    mask = [item[1] for item in batch]
    label = [item[2] for item in batch]
    return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)

In [14]:
train_dataloader = DataLoader(train_set, batch_size=96, num_workers=2,shuffle=True, collate_fn=_collate_fn)

In [15]:
num_epochs = 20

In [16]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * 2)

In [17]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [18]:
neg = -1e18

In [19]:
log_interval = 100

In [None]:
max_loss = 10000
for e in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_id, (token_ids, mask, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().cuda()
        mask = mask.long().cuda()
        label = label.long().cuda()
        out = model(token_ids)
        mask_3d = mask.unsqueeze(dim=2).repeat_interleave(repeats=out.shape[2], dim=2)
        mask_out = torch.where(mask_3d == 1, out, neg * torch.ones_like(out))
        loss = loss_fn(mask_out.transpose(2, 1), label)
        loss.backward()
        train_loss += loss
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        if batch_id % log_interval == 0:
            print("EPOCH {} [{}/{}]  >>>  loss : {:.6f}\t  train_loss : {:.3f}".format(e+1, batch_id+1,len(train_dataloader),
                                                                           loss.data.cpu().numpy(), train_loss / (batch_id+1)))
    print("EPOCH {}  >>>  loss : {:.6f}\t  train_loss : {:.3f}".format(e+1, loss.data.cpu().numpy(), train_loss / (batch_id+1)))
    if max_loss > train_loss / (batch_id+1):
      max_loss = train_loss / (batch_id+1)
      torch.save(model.state_dict(), '/content/drive/MyDrive/자연어처리음성인식/조상연/TorchGPT2.pt')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 1 [1/1324]  >>>  loss : 9.004754	  train_loss : 9.005
EPOCH 1 [101/1324]  >>>  loss : 8.962007	  train_loss : 8.945
EPOCH 1 [201/1324]  >>>  loss : 8.594399	  train_loss : 8.842
EPOCH 1 [301/1324]  >>>  loss : 8.380759	  train_loss : 8.720
EPOCH 1 [401/1324]  >>>  loss : 8.354620	  train_loss : 8.591
EPOCH 1 [501/1324]  >>>  loss : 8.223542	  train_loss : 8.470
EPOCH 1 [601/1324]  >>>  loss : 8.001460	  train_loss : 8.372
EPOCH 1 [701/1324]  >>>  loss : 7.972234	  train_loss : 8.291
EPOCH 1 [801/1324]  >>>  loss : 7.834918	  train_loss : 8.225
EPOCH 1 [901/1324]  >>>  loss : 7.780523	  train_loss : 8.171
EPOCH 1 [1001/1324]  >>>  loss : 7.589610	  train_loss : 8.121
EPOCH 1 [1101/1324]  >>>  loss : 7.794427	  train_loss : 8.080
EPOCH 1 [1201/1324]  >>>  loss : 7.839632	  train_loss : 8.043
EPOCH 1 [1301/1324]  >>>  loss : 7.774604	  train_loss : 8.011
EPOCH 1  >>>  loss : 7.561442	  train_loss : 8.004


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 2 [1/1324]  >>>  loss : 7.479500	  train_loss : 7.480
EPOCH 2 [101/1324]  >>>  loss : 7.547923	  train_loss : 7.602
EPOCH 2 [201/1324]  >>>  loss : 7.517080	  train_loss : 7.591
EPOCH 2 [301/1324]  >>>  loss : 7.491930	  train_loss : 7.588
EPOCH 2 [401/1324]  >>>  loss : 7.503162	  train_loss : 7.577
EPOCH 2 [501/1324]  >>>  loss : 7.364359	  train_loss : 7.574
EPOCH 2 [601/1324]  >>>  loss : 7.530267	  train_loss : 7.564
EPOCH 2 [701/1324]  >>>  loss : 7.368507	  train_loss : 7.556
EPOCH 2 [801/1324]  >>>  loss : 7.586389	  train_loss : 7.552
EPOCH 2 [901/1324]  >>>  loss : 7.734723	  train_loss : 7.548
EPOCH 2 [1001/1324]  >>>  loss : 7.487865	  train_loss : 7.544
EPOCH 2 [1101/1324]  >>>  loss : 7.425124	  train_loss : 7.540
EPOCH 2 [1201/1324]  >>>  loss : 7.576977	  train_loss : 7.535
EPOCH 2 [1301/1324]  >>>  loss : 7.604053	  train_loss : 7.531
EPOCH 2  >>>  loss : 7.639875	  train_loss : 7.530


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 3 [1/1324]  >>>  loss : 7.689527	  train_loss : 7.690
EPOCH 3 [101/1324]  >>>  loss : 7.409799	  train_loss : 7.481
EPOCH 3 [201/1324]  >>>  loss : 7.695319	  train_loss : 7.479
EPOCH 3 [301/1324]  >>>  loss : 7.698575	  train_loss : 7.470
EPOCH 3 [401/1324]  >>>  loss : 7.405835	  train_loss : 7.465
EPOCH 3 [501/1324]  >>>  loss : 7.624668	  train_loss : 7.465
EPOCH 3 [601/1324]  >>>  loss : 7.528747	  train_loss : 7.466
EPOCH 3 [701/1324]  >>>  loss : 7.309188	  train_loss : 7.463
EPOCH 3 [801/1324]  >>>  loss : 7.233407	  train_loss : 7.463
EPOCH 3 [901/1324]  >>>  loss : 7.571120	  train_loss : 7.459
EPOCH 3 [1001/1324]  >>>  loss : 7.371071	  train_loss : 7.457
EPOCH 3 [1101/1324]  >>>  loss : 7.428599	  train_loss : 7.456
EPOCH 3 [1201/1324]  >>>  loss : 7.393484	  train_loss : 7.455
EPOCH 3 [1301/1324]  >>>  loss : 7.380447	  train_loss : 7.453
EPOCH 3  >>>  loss : 7.330264	  train_loss : 7.452


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 4 [1/1324]  >>>  loss : 7.518835	  train_loss : 7.519
EPOCH 4 [101/1324]  >>>  loss : 7.425248	  train_loss : 7.421
EPOCH 4 [201/1324]  >>>  loss : 7.466825	  train_loss : 7.425
EPOCH 4 [301/1324]  >>>  loss : 7.428236	  train_loss : 7.423
EPOCH 4 [401/1324]  >>>  loss : 7.372717	  train_loss : 7.424
EPOCH 4 [501/1324]  >>>  loss : 7.546734	  train_loss : 7.426
EPOCH 4 [601/1324]  >>>  loss : 7.363872	  train_loss : 7.424
EPOCH 4 [701/1324]  >>>  loss : 7.419014	  train_loss : 7.422
EPOCH 4 [801/1324]  >>>  loss : 7.410894	  train_loss : 7.423
EPOCH 4 [901/1324]  >>>  loss : 7.486710	  train_loss : 7.423
EPOCH 4 [1001/1324]  >>>  loss : 7.446867	  train_loss : 7.421
EPOCH 4 [1101/1324]  >>>  loss : 7.455899	  train_loss : 7.420
EPOCH 4 [1201/1324]  >>>  loss : 7.573426	  train_loss : 7.419
EPOCH 4 [1301/1324]  >>>  loss : 7.364918	  train_loss : 7.418
EPOCH 4  >>>  loss : 7.500346	  train_loss : 7.417


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 5 [1/1324]  >>>  loss : 7.533074	  train_loss : 7.533
EPOCH 5 [101/1324]  >>>  loss : 7.383386	  train_loss : 7.407
EPOCH 5 [201/1324]  >>>  loss : 7.260992	  train_loss : 7.409
EPOCH 5 [301/1324]  >>>  loss : 7.369443	  train_loss : 7.401
EPOCH 5 [401/1324]  >>>  loss : 7.233860	  train_loss : 7.400
EPOCH 5 [501/1324]  >>>  loss : 7.289654	  train_loss : 7.392
EPOCH 5 [601/1324]  >>>  loss : 7.385302	  train_loss : 7.389
EPOCH 5 [701/1324]  >>>  loss : 7.277678	  train_loss : 7.391
EPOCH 5 [801/1324]  >>>  loss : 7.325247	  train_loss : 7.393
EPOCH 5 [901/1324]  >>>  loss : 7.398857	  train_loss : 7.394
EPOCH 5 [1001/1324]  >>>  loss : 7.460894	  train_loss : 7.394
EPOCH 5 [1101/1324]  >>>  loss : 7.397470	  train_loss : 7.394
EPOCH 5 [1201/1324]  >>>  loss : 7.375503	  train_loss : 7.394
EPOCH 5 [1301/1324]  >>>  loss : 7.555219	  train_loss : 7.391
EPOCH 5  >>>  loss : 7.301095	  train_loss : 7.390


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 6 [1/1324]  >>>  loss : 7.429166	  train_loss : 7.429
EPOCH 6 [101/1324]  >>>  loss : 7.326519	  train_loss : 7.363
EPOCH 6 [201/1324]  >>>  loss : 7.340415	  train_loss : 7.366
EPOCH 6 [301/1324]  >>>  loss : 7.322943	  train_loss : 7.363
EPOCH 6 [401/1324]  >>>  loss : 7.549225	  train_loss : 7.365
EPOCH 6 [501/1324]  >>>  loss : 7.500753	  train_loss : 7.366
EPOCH 6 [601/1324]  >>>  loss : 7.390862	  train_loss : 7.369
EPOCH 6 [701/1324]  >>>  loss : 7.296257	  train_loss : 7.368
EPOCH 6 [801/1324]  >>>  loss : 7.632968	  train_loss : 7.368
EPOCH 6 [901/1324]  >>>  loss : 7.334387	  train_loss : 7.368
EPOCH 6 [1001/1324]  >>>  loss : 7.102043	  train_loss : 7.367
EPOCH 6 [1101/1324]  >>>  loss : 7.365180	  train_loss : 7.368
EPOCH 6 [1201/1324]  >>>  loss : 7.297596	  train_loss : 7.367
EPOCH 6 [1301/1324]  >>>  loss : 7.333791	  train_loss : 7.367
EPOCH 6  >>>  loss : 7.430753	  train_loss : 7.367


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 7 [1/1324]  >>>  loss : 7.319280	  train_loss : 7.319
EPOCH 7 [101/1324]  >>>  loss : 7.274792	  train_loss : 7.352
EPOCH 7 [201/1324]  >>>  loss : 7.303476	  train_loss : 7.354
EPOCH 7 [301/1324]  >>>  loss : 7.240890	  train_loss : 7.347
EPOCH 7 [401/1324]  >>>  loss : 7.224344	  train_loss : 7.347
EPOCH 7 [501/1324]  >>>  loss : 7.613572	  train_loss : 7.347
EPOCH 7 [601/1324]  >>>  loss : 7.226639	  train_loss : 7.346
EPOCH 7 [701/1324]  >>>  loss : 7.392847	  train_loss : 7.345
EPOCH 7 [801/1324]  >>>  loss : 7.223083	  train_loss : 7.348
EPOCH 7 [901/1324]  >>>  loss : 7.206172	  train_loss : 7.347
EPOCH 7 [1001/1324]  >>>  loss : 7.171844	  train_loss : 7.346
EPOCH 7 [1101/1324]  >>>  loss : 7.413023	  train_loss : 7.347
EPOCH 7 [1201/1324]  >>>  loss : 7.664667	  train_loss : 7.347
EPOCH 7 [1301/1324]  >>>  loss : 7.251233	  train_loss : 7.345
EPOCH 7  >>>  loss : 7.447299	  train_loss : 7.345


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 8 [1/1324]  >>>  loss : 7.268202	  train_loss : 7.268
EPOCH 8 [101/1324]  >>>  loss : 7.483477	  train_loss : 7.310
EPOCH 8 [201/1324]  >>>  loss : 7.279978	  train_loss : 7.320
EPOCH 8 [301/1324]  >>>  loss : 7.430490	  train_loss : 7.316
EPOCH 8 [401/1324]  >>>  loss : 7.327652	  train_loss : 7.321
EPOCH 8 [501/1324]  >>>  loss : 7.216246	  train_loss : 7.322
EPOCH 8 [601/1324]  >>>  loss : 7.333811	  train_loss : 7.324
EPOCH 8 [701/1324]  >>>  loss : 7.415348	  train_loss : 7.327
EPOCH 8 [801/1324]  >>>  loss : 7.273426	  train_loss : 7.327
EPOCH 8 [901/1324]  >>>  loss : 7.224213	  train_loss : 7.326
EPOCH 8 [1001/1324]  >>>  loss : 7.422332	  train_loss : 7.324
EPOCH 8 [1101/1324]  >>>  loss : 7.199059	  train_loss : 7.323
EPOCH 8 [1201/1324]  >>>  loss : 7.230612	  train_loss : 7.325
EPOCH 8 [1301/1324]  >>>  loss : 7.198141	  train_loss : 7.324
EPOCH 8  >>>  loss : 7.424120	  train_loss : 7.324


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 9 [1/1324]  >>>  loss : 7.069010	  train_loss : 7.069
EPOCH 9 [101/1324]  >>>  loss : 7.216948	  train_loss : 7.299
EPOCH 9 [201/1324]  >>>  loss : 7.263527	  train_loss : 7.297
EPOCH 9 [301/1324]  >>>  loss : 7.340974	  train_loss : 7.297
EPOCH 9 [401/1324]  >>>  loss : 7.213786	  train_loss : 7.299
EPOCH 9 [501/1324]  >>>  loss : 7.304029	  train_loss : 7.295
EPOCH 9 [601/1324]  >>>  loss : 7.366665	  train_loss : 7.298
EPOCH 9 [701/1324]  >>>  loss : 7.223237	  train_loss : 7.298
EPOCH 9 [801/1324]  >>>  loss : 7.481496	  train_loss : 7.299
EPOCH 9 [901/1324]  >>>  loss : 7.258973	  train_loss : 7.299
EPOCH 9 [1001/1324]  >>>  loss : 7.299471	  train_loss : 7.300
EPOCH 9 [1101/1324]  >>>  loss : 7.338610	  train_loss : 7.302
EPOCH 9 [1201/1324]  >>>  loss : 7.382015	  train_loss : 7.304
EPOCH 9 [1301/1324]  >>>  loss : 7.300981	  train_loss : 7.303
EPOCH 9  >>>  loss : 7.533584	  train_loss : 7.303


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 10 [1/1324]  >>>  loss : 7.273294	  train_loss : 7.273
EPOCH 10 [101/1324]  >>>  loss : 7.131752	  train_loss : 7.273
EPOCH 10 [201/1324]  >>>  loss : 7.513454	  train_loss : 7.276
EPOCH 10 [301/1324]  >>>  loss : 7.282618	  train_loss : 7.274
EPOCH 10 [401/1324]  >>>  loss : 7.091747	  train_loss : 7.272
EPOCH 10 [501/1324]  >>>  loss : 7.372317	  train_loss : 7.275
EPOCH 10 [601/1324]  >>>  loss : 7.294901	  train_loss : 7.277
EPOCH 10 [701/1324]  >>>  loss : 7.242230	  train_loss : 7.278
EPOCH 10 [801/1324]  >>>  loss : 7.205026	  train_loss : 7.281
EPOCH 10 [901/1324]  >>>  loss : 7.454824	  train_loss : 7.281
EPOCH 10 [1001/1324]  >>>  loss : 7.276220	  train_loss : 7.280
EPOCH 10 [1101/1324]  >>>  loss : 7.269432	  train_loss : 7.281
EPOCH 10 [1201/1324]  >>>  loss : 7.389666	  train_loss : 7.282
EPOCH 10 [1301/1324]  >>>  loss : 7.453498	  train_loss : 7.281
EPOCH 10  >>>  loss : 6.928863	  train_loss : 7.282


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 11 [1/1324]  >>>  loss : 7.268239	  train_loss : 7.268
EPOCH 11 [101/1324]  >>>  loss : 7.344271	  train_loss : 7.269
EPOCH 11 [201/1324]  >>>  loss : 7.259857	  train_loss : 7.247
EPOCH 11 [301/1324]  >>>  loss : 7.261972	  train_loss : 7.247
EPOCH 11 [401/1324]  >>>  loss : 7.441783	  train_loss : 7.248
EPOCH 11 [501/1324]  >>>  loss : 7.443061	  train_loss : 7.250
EPOCH 11 [601/1324]  >>>  loss : 7.319941	  train_loss : 7.253
EPOCH 11 [701/1324]  >>>  loss : 7.348902	  train_loss : 7.255
EPOCH 11 [801/1324]  >>>  loss : 7.107146	  train_loss : 7.255
EPOCH 11 [901/1324]  >>>  loss : 7.282219	  train_loss : 7.257
EPOCH 11 [1001/1324]  >>>  loss : 7.328679	  train_loss : 7.258
EPOCH 11 [1101/1324]  >>>  loss : 7.321667	  train_loss : 7.260
EPOCH 11 [1201/1324]  >>>  loss : 7.286997	  train_loss : 7.259
EPOCH 11 [1301/1324]  >>>  loss : 7.346123	  train_loss : 7.260
EPOCH 11  >>>  loss : 7.132276	  train_loss : 7.260


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 12 [1/1324]  >>>  loss : 7.326019	  train_loss : 7.326
EPOCH 12 [101/1324]  >>>  loss : 7.089913	  train_loss : 7.234
EPOCH 12 [201/1324]  >>>  loss : 7.383419	  train_loss : 7.231
EPOCH 12 [301/1324]  >>>  loss : 7.256878	  train_loss : 7.231
EPOCH 12 [401/1324]  >>>  loss : 7.099374	  train_loss : 7.233
EPOCH 12 [501/1324]  >>>  loss : 7.072784	  train_loss : 7.234
EPOCH 12 [601/1324]  >>>  loss : 6.969776	  train_loss : 7.235
EPOCH 12 [701/1324]  >>>  loss : 7.222501	  train_loss : 7.238
EPOCH 12 [801/1324]  >>>  loss : 7.467331	  train_loss : 7.238
EPOCH 12 [901/1324]  >>>  loss : 7.351922	  train_loss : 7.239
EPOCH 12 [1001/1324]  >>>  loss : 7.230734	  train_loss : 7.238
EPOCH 12 [1101/1324]  >>>  loss : 7.203903	  train_loss : 7.239
EPOCH 12 [1201/1324]  >>>  loss : 7.110007	  train_loss : 7.239
EPOCH 12 [1301/1324]  >>>  loss : 7.167208	  train_loss : 7.238
EPOCH 12  >>>  loss : 7.058883	  train_loss : 7.238


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 13 [1/1324]  >>>  loss : 7.317340	  train_loss : 7.317
EPOCH 13 [101/1324]  >>>  loss : 7.238384	  train_loss : 7.204
EPOCH 13 [201/1324]  >>>  loss : 7.080784	  train_loss : 7.207
EPOCH 13 [301/1324]  >>>  loss : 7.124874	  train_loss : 7.207
EPOCH 13 [401/1324]  >>>  loss : 7.153809	  train_loss : 7.207
EPOCH 13 [501/1324]  >>>  loss : 7.108938	  train_loss : 7.206
EPOCH 13 [601/1324]  >>>  loss : 7.031200	  train_loss : 7.208
EPOCH 13 [701/1324]  >>>  loss : 7.167429	  train_loss : 7.206
EPOCH 13 [801/1324]  >>>  loss : 7.258855	  train_loss : 7.210
EPOCH 13 [901/1324]  >>>  loss : 7.143007	  train_loss : 7.211
EPOCH 13 [1001/1324]  >>>  loss : 7.384207	  train_loss : 7.213
EPOCH 13 [1101/1324]  >>>  loss : 7.255016	  train_loss : 7.213
EPOCH 13 [1201/1324]  >>>  loss : 7.251708	  train_loss : 7.214
EPOCH 13 [1301/1324]  >>>  loss : 7.158675	  train_loss : 7.216
EPOCH 13  >>>  loss : 7.267702	  train_loss : 7.216


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 14 [1/1324]  >>>  loss : 7.137805	  train_loss : 7.138
EPOCH 14 [101/1324]  >>>  loss : 7.219689	  train_loss : 7.162
EPOCH 14 [201/1324]  >>>  loss : 7.151935	  train_loss : 7.169
EPOCH 14 [301/1324]  >>>  loss : 7.210392	  train_loss : 7.172
EPOCH 14 [401/1324]  >>>  loss : 7.218712	  train_loss : 7.176
EPOCH 14 [501/1324]  >>>  loss : 7.067774	  train_loss : 7.179
EPOCH 14 [601/1324]  >>>  loss : 7.180995	  train_loss : 7.183
EPOCH 14 [701/1324]  >>>  loss : 7.138132	  train_loss : 7.185
EPOCH 14 [801/1324]  >>>  loss : 7.256259	  train_loss : 7.189
EPOCH 14 [901/1324]  >>>  loss : 7.005255	  train_loss : 7.188
EPOCH 14 [1001/1324]  >>>  loss : 7.110002	  train_loss : 7.191
EPOCH 14 [1101/1324]  >>>  loss : 7.096103	  train_loss : 7.191
EPOCH 14 [1201/1324]  >>>  loss : 7.172590	  train_loss : 7.193
EPOCH 14 [1301/1324]  >>>  loss : 7.258641	  train_loss : 7.194
EPOCH 14  >>>  loss : 7.132412	  train_loss : 7.194


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 15 [1/1324]  >>>  loss : 7.275088	  train_loss : 7.275
EPOCH 15 [101/1324]  >>>  loss : 7.027993	  train_loss : 7.152
EPOCH 15 [201/1324]  >>>  loss : 7.283474	  train_loss : 7.149
EPOCH 15 [301/1324]  >>>  loss : 7.061711	  train_loss : 7.155
EPOCH 15 [401/1324]  >>>  loss : 7.102764	  train_loss : 7.158
EPOCH 15 [501/1324]  >>>  loss : 7.185499	  train_loss : 7.161
EPOCH 15 [601/1324]  >>>  loss : 7.098188	  train_loss : 7.164
EPOCH 15 [701/1324]  >>>  loss : 7.222246	  train_loss : 7.164
EPOCH 15 [801/1324]  >>>  loss : 7.372873	  train_loss : 7.165
EPOCH 15 [901/1324]  >>>  loss : 7.123748	  train_loss : 7.165
EPOCH 15 [1001/1324]  >>>  loss : 7.151630	  train_loss : 7.169
EPOCH 15 [1101/1324]  >>>  loss : 7.218566	  train_loss : 7.169
EPOCH 15 [1201/1324]  >>>  loss : 7.146011	  train_loss : 7.171
EPOCH 15 [1301/1324]  >>>  loss : 7.112636	  train_loss : 7.172
EPOCH 15  >>>  loss : 6.862979	  train_loss : 7.171


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 16 [1/1324]  >>>  loss : 7.205949	  train_loss : 7.206
EPOCH 16 [101/1324]  >>>  loss : 6.802911	  train_loss : 7.140
EPOCH 16 [201/1324]  >>>  loss : 7.006626	  train_loss : 7.145
EPOCH 16 [301/1324]  >>>  loss : 7.030201	  train_loss : 7.143
EPOCH 16 [401/1324]  >>>  loss : 7.321146	  train_loss : 7.146
EPOCH 16 [501/1324]  >>>  loss : 7.047882	  train_loss : 7.143
EPOCH 16 [601/1324]  >>>  loss : 7.251835	  train_loss : 7.143
EPOCH 16 [701/1324]  >>>  loss : 7.183836	  train_loss : 7.145
EPOCH 16 [801/1324]  >>>  loss : 7.036782	  train_loss : 7.149
EPOCH 16 [901/1324]  >>>  loss : 7.281355	  train_loss : 7.147
EPOCH 16 [1001/1324]  >>>  loss : 7.259498	  train_loss : 7.147
EPOCH 16 [1101/1324]  >>>  loss : 7.360819	  train_loss : 7.147
EPOCH 16 [1201/1324]  >>>  loss : 7.099405	  train_loss : 7.150
EPOCH 16 [1301/1324]  >>>  loss : 7.078187	  train_loss : 7.149
EPOCH 16  >>>  loss : 7.172024	  train_loss : 7.149


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 17 [1/1324]  >>>  loss : 7.143270	  train_loss : 7.143
EPOCH 17 [101/1324]  >>>  loss : 6.796942	  train_loss : 7.124
EPOCH 17 [201/1324]  >>>  loss : 7.097301	  train_loss : 7.124
EPOCH 17 [301/1324]  >>>  loss : 6.974536	  train_loss : 7.116
EPOCH 17 [401/1324]  >>>  loss : 7.336038	  train_loss : 7.119
EPOCH 17 [501/1324]  >>>  loss : 7.033988	  train_loss : 7.117
EPOCH 17 [601/1324]  >>>  loss : 7.167417	  train_loss : 7.118
EPOCH 17 [701/1324]  >>>  loss : 7.130301	  train_loss : 7.121
EPOCH 17 [801/1324]  >>>  loss : 7.127720	  train_loss : 7.122
EPOCH 17 [901/1324]  >>>  loss : 7.144053	  train_loss : 7.121
EPOCH 17 [1001/1324]  >>>  loss : 7.178192	  train_loss : 7.122
EPOCH 17 [1101/1324]  >>>  loss : 7.078286	  train_loss : 7.124
EPOCH 17 [1201/1324]  >>>  loss : 7.004801	  train_loss : 7.124
EPOCH 17 [1301/1324]  >>>  loss : 7.154905	  train_loss : 7.125
EPOCH 17  >>>  loss : 7.037308	  train_loss : 7.125


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 18 [1/1324]  >>>  loss : 7.225228	  train_loss : 7.225
EPOCH 18 [101/1324]  >>>  loss : 7.137934	  train_loss : 7.073
EPOCH 18 [201/1324]  >>>  loss : 6.996492	  train_loss : 7.084
EPOCH 18 [301/1324]  >>>  loss : 7.251484	  train_loss : 7.084
EPOCH 18 [401/1324]  >>>  loss : 6.958508	  train_loss : 7.088
EPOCH 18 [501/1324]  >>>  loss : 7.300373	  train_loss : 7.092
EPOCH 18 [601/1324]  >>>  loss : 7.030933	  train_loss : 7.090
EPOCH 18 [701/1324]  >>>  loss : 7.429233	  train_loss : 7.094
EPOCH 18 [801/1324]  >>>  loss : 6.797106	  train_loss : 7.093
EPOCH 18 [901/1324]  >>>  loss : 7.083065	  train_loss : 7.095
EPOCH 18 [1001/1324]  >>>  loss : 6.984829	  train_loss : 7.096
EPOCH 18 [1101/1324]  >>>  loss : 7.170963	  train_loss : 7.097
EPOCH 18 [1201/1324]  >>>  loss : 7.223985	  train_loss : 7.100
EPOCH 18 [1301/1324]  >>>  loss : 7.025172	  train_loss : 7.102
EPOCH 18  >>>  loss : 7.042017	  train_loss : 7.102


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 19 [1/1324]  >>>  loss : 7.158290	  train_loss : 7.158
EPOCH 19 [101/1324]  >>>  loss : 7.139378	  train_loss : 7.059
EPOCH 19 [201/1324]  >>>  loss : 7.166185	  train_loss : 7.064
EPOCH 19 [301/1324]  >>>  loss : 7.155019	  train_loss : 7.063
EPOCH 19 [401/1324]  >>>  loss : 7.067724	  train_loss : 7.062
EPOCH 19 [501/1324]  >>>  loss : 7.050869	  train_loss : 7.063
EPOCH 19 [601/1324]  >>>  loss : 7.097532	  train_loss : 7.069
EPOCH 19 [701/1324]  >>>  loss : 7.351290	  train_loss : 7.068
EPOCH 19 [801/1324]  >>>  loss : 7.043020	  train_loss : 7.069
EPOCH 19 [901/1324]  >>>  loss : 7.118125	  train_loss : 7.072
EPOCH 19 [1001/1324]  >>>  loss : 7.005401	  train_loss : 7.072
EPOCH 19 [1101/1324]  >>>  loss : 7.062823	  train_loss : 7.073
EPOCH 19 [1201/1324]  >>>  loss : 7.072440	  train_loss : 7.075
EPOCH 19 [1301/1324]  >>>  loss : 7.271628	  train_loss : 7.078
EPOCH 19  >>>  loss : 7.256765	  train_loss : 7.079


  0%|          | 0/1324 [00:00<?, ?it/s]

  """
  """


EPOCH 20 [1/1324]  >>>  loss : 7.154314	  train_loss : 7.154
EPOCH 20 [101/1324]  >>>  loss : 7.040008	  train_loss : 7.024
EPOCH 20 [201/1324]  >>>  loss : 7.018530	  train_loss : 7.027
EPOCH 20 [301/1324]  >>>  loss : 7.040117	  train_loss : 7.034
EPOCH 20 [401/1324]  >>>  loss : 7.071582	  train_loss : 7.037
EPOCH 20 [501/1324]  >>>  loss : 7.141993	  train_loss : 7.039
EPOCH 20 [601/1324]  >>>  loss : 6.939913	  train_loss : 7.040
EPOCH 20 [701/1324]  >>>  loss : 7.094103	  train_loss : 7.043
EPOCH 20 [801/1324]  >>>  loss : 7.122448	  train_loss : 7.048
EPOCH 20 [901/1324]  >>>  loss : 7.019066	  train_loss : 7.048
EPOCH 20 [1001/1324]  >>>  loss : 7.075564	  train_loss : 7.048
EPOCH 20 [1101/1324]  >>>  loss : 7.305051	  train_loss : 7.050
EPOCH 20 [1201/1324]  >>>  loss : 7.143764	  train_loss : 7.052
EPOCH 20 [1301/1324]  >>>  loss : 7.251099	  train_loss : 7.055
EPOCH 20  >>>  loss : 6.955148	  train_loss : 7.056
