In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and source edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import logging

import torch
from torch.utils.data import DataLoader

from config import read_config, config_str
from data import TextDataset
from trainer import Trainer

import gc

In [3]:
# Send every log record to both debug.log (truncated at the start of each run)
# and the notebook output.
#
# force=True (Python 3.8+) removes and closes handlers left on the root logger
# by an earlier run of this cell. Without it, basicConfig() is a no-op on
# re-run, yet the freshly constructed FileHandler below has already truncated
# debug.log while the old handler keeps writing into it.
logging.basicConfig(
    handlers=[logging.FileHandler("debug.log", mode='w'), logging.StreamHandler()],
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    force=True,
)

# Restrict which GPUs this process may see. This has to be set before the first
# CUDA call in the kernel for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

In [4]:
# Drop cached CUDA allocations and garbage left over from earlier runs of this
# kernel before new datasets are created.
torch.cuda.empty_cache()
gc.collect()

config = read_config()

# Record the full configuration so the log documents what this run used.
logging.info(config_str(config))

batch_size = config["batch_size"]

# Pinned host memory only helps when batches are copied to a GPU. Requesting it
# on a CPU-only machine just makes DataLoader emit a warning, and this cell
# explicitly supports the CPU fallback when choosing the device further down.
pin_memory = torch.cuda.is_available()

train_dataset = TextDataset(config, "train")
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=TextDataset.get_collate_fn(),
    num_workers=config["num_workers"],
    pin_memory=pin_memory,
)

val_dataset = TextDataset(config, "val")
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=TextDataset.get_collate_fn(),
    num_workers=config["num_workers"],
    pin_memory=pin_memory,
)

# Passed to Trainer in a later cell; presumably the maximum sequence length
# the model handles -- TODO confirm against trainer/model code.
max_length = 128
# (source vocabulary size, target vocabulary size)
vocab_sizes = len(train_dataset.src_vocab), len(train_dataset.dst_vocab)

logging.info("Vocab sizes {}".format(vocab_sizes))

# Use the first visible GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logging.info("Device {}".format(device))

[2023-03-05 03:25:54] 
datadir: data
vocabdir: vocab
dataset: 
	train: 
		src: train.de-en.de
		dst: train.de-en.en
	val: 
		src: val.de-en.de
		dst: val.de-en.en
	test: 
		src: test1.de-en.de
language: 
	src: de
	dst: en
model: 
	num_encoder_layers: 3
	num_decoder_layers: 3
	embedding_dim: 512
	num_heads: 8
	feedforward_dim: 512
	dropout: 0.1
batch_size: 64
num_workers: 4
optimizer: 
	lr: 1e-4
	beta1: 0.9
	beta2: 0.98
	eps: 1e-9
epochs: 10
device_ids: [0, 1, 2]
checkpoint: 
	dir: checkpoints
	step: 1
[2023-03-05 03:25:56] Loaded vocab train de/en from vocab/vocab.pth
[2023-03-05 03:25:56] Loaded vocab val de/en from vocab/vocab.pth
[2023-03-05 03:25:56] Vocab sizes (123554, 56326)
[2023-03-05 03:25:56] Device cuda:0


In [5]:
# Build the trainer for the run named "test" and load its epoch-10 checkpoint
# rather than training in this session. To train instead, call
# trainer.train(config, train_loader, val_loader) in place of the load below.
trainer = Trainer(
    config,
    vocab_sizes,
    max_length,
    device,
    run_name="test",
)
trainer.load_from_checkpoint(epoch=10)

[2023-03-05 03:26:33] Checkpoint is loaded from checkpoints/test/epoch_10.ckpt with val_loss 1.98711


In [8]:
from model import translate
from tqdm import tqdm

# Test split, served one example per batch and in file order so predictions
# line up with the input sentences.
test_dataset = TextDataset(config, "test")
test_collate_fn = TextDataset.get_collate_fn(is_test=True)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=test_collate_fn,
)

# Vocabulary used to turn predicted token ids back into English tokens.
dst_vocab = train_dataset.vocabs["en"]

[2023-03-05 03:29:15] Loaded vocab test de/* from vocab/vocab.pth


In [11]:
# Translate the test set one sentence at a time (each batch from test_loader
# holds a single example, hence src[0]) and collect the decoded sentences.
translated = []

# Sequence-boundary markers that must not appear in the written predictions.
special_tokens = {"<bos>", "<eos>"}

# Run inference with dropout disabled and without autograd bookkeeping.
# translate() may already do both internally -- not visible from this
# notebook -- in which case repeating them here is harmless.
inference_model = trainer.model
was_training = inference_model.training
inference_model.eval()
try:
    with torch.no_grad():
        for src, _ in tqdm(test_loader):
            # .tolist() yields plain Python ints for the vocabulary lookup.
            dst_tokens = translate(inference_model, src[0], device).cpu().tolist()
            # Drop special tokens before joining so that removing one never
            # leaves a doubled space inside the sentence.
            words = [
                token
                for token in dst_vocab.lookup_tokens(dst_tokens)
                if token not in special_tokens
            ]
            translated.append(" ".join(words).strip())
finally:
    # Put the model back in whatever mode it was in, so a later training call
    # is not silently affected by this cell.
    inference_model.train(was_training)

100%|██████████| 2998/2998 [06:12<00:00,  8.04it/s]


In [12]:
# Write one translation per line. The encoding is fixed to UTF-8 so the file
# does not depend on the platform's locale and non-ASCII characters in a
# translation cannot make the write fail.
with open("prediction.txt", "w", encoding="utf-8") as f:
    f.writelines(sentence + "\n" for sentence in translated)