In [1]:
%load_ext autoreload
%autoreload 2

import sys

import torch

# Make the kbclean package importable from this notebook; the membership
# check keeps sys.path free of duplicates when the cell is re-run.
kbclean_dir = r"../../../kb-data-cleaning/kbclean"
if kbclean_dir not in sys.path:
    sys.path.append(kbclean_dir)

# Experiment name; the config-loading cell uses it to pick "<method>.yaml".
method = "gan"

## Load hyper-parameters for experiments

In [2]:
import yaml

# Read the hyper-parameters for the selected method from its YAML config.
# The context manager closes the file handle, which a bare open() passed
# straight into yaml.load() would leave open.
# NOTE(review): FullLoader is kept to preserve behavior; if the config only
# holds plain scalars/lists/dicts, yaml.safe_load would be the safer choice.
with open(f"../../config/{method}.yaml", "r") as config_file:
    hparams = yaml.load(config_file, Loader=yaml.FullLoader)
hparams

{'batch_size': 3500,
 'gen_emb_dim': 32,
 'dis_emb_dim': 64,
 'filter_sizes': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20],
 'num_filters': [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160],
 'dropout_p': 0.75,
 'lr': 0.0005,
 'amp_level': 'O1',
 'max_length': 100}

In [3]:
from argparse import Namespace

import torch

# Wrap the config dict in a Namespace so values are read as attributes
# (e.g. hparams.batch_size). The isinstance guard keeps this cell re-runnable:
# once hparams is already a Namespace, Namespace(**hparams) would raise
# TypeError because a Namespace is not a mapping.
if isinstance(hparams, dict):
    hparams = Namespace(**hparams)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Preprocess data

In [4]:
import pandas as pd

# Load the training CSV into a DataFrame.
df = pd.read_csv("../../data/train/sherlock-all/first100k.csv")

# Extract the "data" column as a plain Python list of cell values.
data_column = df["data"]
data = data_column.values.tolist()

In [5]:
# Preview the first five values and their labels.
# NOTE(review): `labels` is not assigned in any cell visible in this notebook;
# on a fresh kernel this line raises NameError unless it is defined
# elsewhere — confirm and restore the assignment (presumably a column of `df`).
data[:5], labels[:5]

(['Story',
  'Print, billboards, press ads, ambient',
  'Story',
  'Print, billboards, press ads, ambient',
  'Print, billboards, press ads, ambient'],
 ['type', 'type', 'type', 'type', 'type'])

In [6]:
import string

import regex as re
from torchnlp.encoders.text import CharacterEncoder

# Build a character-level encoder from the raw values (append_eos enabled).
char_encoder = CharacterEncoder(data, append_eos=True)

# Store the encoder's vocabulary size on the hyper-parameters and display it.
vocab_size = char_encoder.vocab_size
hparams.vocab_size = vocab_size
vocab_size

In [8]:
from torch.utils.data import DataLoader, SequentialSampler, random_split
from torchnlp.encoders.text import stack_and_pad_tensors
from torchnlp.samplers import BucketBatchSampler

# Drop falsy (empty) examples and clip every remaining one to its first 100
# characters. The original comprehension was missing its `for example` clause
# and did not parse.
# NOTE(review): 100 equals the `max_length` shown in the loaded config —
# consider reading hparams.max_length here; confirm that is the intent.
data = [example[:100] for example in data if example]

def collate_fn(batch):
    """Encode one batch of raw examples with the notebook-level ``char_encoder``.

    Passed to the DataLoaders as ``collate_fn``; returns whatever
    ``char_encoder.batch_encode`` produces for ``batch``.
    """
    encoded_batch = char_encoder.batch_encode(batch)
    return encoded_batch

# 70 % of the examples go to training; the remainder is held out for validation.
train_length = int(len(data) * 0.7)
val_length = len(data) - train_length

# NOTE(review): random_split is called without a fixed seed, so the split
# differs between kernel runs.
train_dataset, val_dataset = random_split(list(data), [train_length, val_length])

len(train_dataset)

995037

In [9]:
# Both loaders share the same batching, collation and worker settings.
loader_kwargs = dict(
    batch_size=hparams.batch_size,
    collate_fn=collate_fn,
    num_workers=16,
)

train_dataloader = DataLoader(train_dataset, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

## Load pre-trained encoder + regular-GAN discriminator

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger

from utils.logger import MyTensorBoardLogger
from models.gan import RGANDiscriminator, OneClassGAN

# TODO(review): the original line was the incomplete assignment `encoder = `,
# a SyntaxError that stopped the whole cell from running. The code that loads
# the pre-trained encoder is not visible in this notebook, so the name is
# bound to an explicit placeholder until that code is restored.
encoder = None

# Discriminator model built from the notebook hyper-parameters and device.
rgan = RGANDiscriminator(hparams, device)

# Trainer on four GPUs with the "dp" backend; checkpoints are written under
# ../../checkpoints/sherlock/ and logs go to ../../tt_logs under the name "rgan".
trainer = Trainer(
    gpus=[0, 1, 2, 3],
    amp_level=hparams.amp_level,
    benchmark=False,
    default_save_path="../../checkpoints/sherlock/",
    distributed_backend="dp",
    logger=MyTensorBoardLogger("../../tt_logs", "rgan"),
)

# fit() takes the model to train as its first argument; the original passed
# the Trainer object itself instead of `rgan`.
# NOTE(review): `val_dataloader` is built earlier but never handed to fit() —
# confirm whether the model supplies its own validation loader.
trainer.fit(rgan, train_dataloader=train_dataloader)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO:lightning:
   | Name                           | Type      | Params
---------------------------------------------------------
0  | attn                           | Attention | 7 K   
1  | attn.attn                      | Linear    | 7 K   
2  | attn.v                         | Linear    | 50    
3  | encoder                        | Encoder   | 78 K  
4  | encoder.embedding              | Embedding | 43 K  
5  | encoder.rnn                    | GRU       | 30 K  
6  | encoder.fc                     | Linear    | 5 K   
7  | encoder.dropout                | Dropout   | 0     
8  | decoder                        | Decoder   | 254 K 
9  | decoder.embedding              | Embedding | 43 K  
10 | decoder.rnn                    | GRU       | 30 K  
11 | decoder.fc_out                 | Linear    | 173 K 
12 | decoder.dropout                | Dropout   | 0     
13 | hidden2latent                

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('Nen eer er n  eeeenn ', 'North American Green'), ('Soe    e         r    </s>', 'Smet, Raoul de more...'), ('F</s>', 'F'), ('—oonnnin</s></s>', '— Wyoming'), ('38</s></s>', '385'), ('Stanaatatal </s></s></s></s></s>', "Sportsman's Park"), ('Ea a a a</s></s></s></s></s></s></s>', 'Engel Hall 102'), ('20001100110', '2011-03-02'), ('0.0</s>', '0.0'), ('Ball</s></s></s>', 'Bengal')]




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('Stoughton</s>', 'Stoughton'), ('Sunday</s>', 'Sunday'), ('NC</s>', 'NC'), ('T11-1</s>', 'T 1-1'), ('Unin  o  iinnentontuuuutt</s>', 'Univ. of Minnesota-Duluth'), ('NAA</s>', 'NYA'), ('Adult</s>', 'Adult'), ('Reveree oortiage</s>', 'Reverse mortgage'), ('0.0</s>', '0.0'), ('Rap Accusation</s>', 'Rap Accusation')]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('UCF</s>', 'UCF'), ('3:39</s>', '3:39'), ('73 yrs 25 days</s>', '73 yrs 25 days'), ('Fluor Corp</s>', 'Fluor Corp'), ('OK14</s>', 'OK14'), ('FIU vs Stetson</s>', 'FIU vs Stetson'), ('single,</s>', 'single,'), ('Ireland</s>', 'Ireland'), ('Rep. Cynthia M  Lummis</s>', 'Rep. Cynthia M. Lummis'), ('0.0</s>', '0.0')]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('retep1944</s>', 'retep1944'), ('JR</s>', 'JR'), ('Barry McHugh</s>', 'Barry McHugh'), ('0.0</s>', '0.0'), ('MN</s>', 'MN'), ('Nhk</s>', 'Nfk'), ('2012</s>', '2012'), ('NORDIC AMER…</s>', 'NORDIC AMER…'), ("Angel's Prayer</s>", "Angel's Prayer"), ('Canada</s>', 'Canada')]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('2</s>', '2'), ('Georgia</s>', 'Georgia'), ('Tyler Flowers</s>', 'Tyler Flowers'), ('Bing Bang Boys</s>', 'Bing Bang Boys'), ('34</s>', '34'), ('Wellington</s>', 'Wellington'), ('1998</s>', '1998'), ('Review</s>', 'Review'), ('Atlanta, GA</s>', 'Atlanta, GA'), ('ISF</s>', 'ISF')]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('CANNONBALL</s>', 'CANNONBALL'), ('Combined sources Manullly validated informatinninffereer r   mmomomiiontoon fo feeeemeeatataand  omp</s>', 'Combined sources Manually validated information inferred from a combination of experimental and comp'), ('Toronto</s>', 'Toronto'), ('25</s>', '25'), ('Nsw</s>', 'Nsw'), ('20</s>', '20'), ('0.0</s>', '0.0'), ('Las Vegas</s>', 'Las Vegas'), ('DCAC</s>', 'DCAC'), ('Pacific Power Blue Sky</s>', 'Pacific Power Blue Sky')]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

[('Montpellier</s>', 'Montpellier'), ('0.0</s>', '0.0'), ('const</s>', 'const'), ('86</s>', '86'), ('Woodside</s>', 'Woodside'), ('278</s>', '278'), ('49</s>', '49'), ('18</s>', '18'), ('Gaby Sanchez</s>', 'Gaby Sanchez'), ('3</s>', '3')]


Process Process-251:
Process Process-254:
Process Process-252:
Process Process-253:
Process Process-248:
Process Process-250:
Process Process-247:
Process Process-255:
Process Process-242:
Process Process-249:
Process Process-256:
Process Process-246:
Process Process-241:
Traceback (most recent call last):
  File "/nas/home/minhpham/miniconda3/lib/python3.7/threading.py", line 1060, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa9a2d01050>
Traceback (most recent call last):
  File "/nas/home/minhpham/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 962, in __del__
    self._shutdown_workers()
  File "/nas/home/minhpham/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 942, in _shutdown_workers
    w.join()
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 140, in join




Traceback (most recent call last):
Traceback (most recent call last):
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()


1

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/nas/home/minhpham/miniconda3/lib/python3.7/multiprocessing/process.py", line 300, in _boots