In [None]:
import torch
from pytorch_metric_learning import losses
import data_handler
from siamese_network import SiameseNetwork, train
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from custom_losses import ContrastiveLoss

In [None]:
# Pick the compute device first so that a machine without CUDA falls back to the
# CPU instead of crashing.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Touch the GPU once (presumably to force CUDA initialisation up front), but only
# when CUDA is actually available: an unconditional `.cuda()` raises on CPU-only
# machines, which made the CPU fallback above unreachable.
if DEVICE.type == "cuda":
    torch.zeros(1).cuda()
#print(f"torch version: {torch.__version__}")

print(f"torch cuda available: {torch.cuda.is_available()}")

In [None]:
# Load the data through data_handler.load, pointing it at dataset/train.csv and
# passing '#' as the separator character. The loader returns two values; only the
# first is used in this notebook, the second is discarded.
df_data, _ = data_handler.load(
    path="dataset/",
    filename_train="train.csv",
    sep_char='#',
)

In [None]:
# Split the loaded frame into a training part and a validation part.
# perc_split=0.8 is presumably the fraction kept for training — TODO confirm
# against data_handler.split_train_data.
split_ratio = 0.8
df_train, df_val = data_handler.split_train_data(
    df_data,
    perc_split=split_ratio,
)

In [None]:
# Apply the same topic concatenation to both splits (training first, then
# validation) and rebind the two names to the results.
# NOTE(review): this cell overwrites its own inputs, so running it a second time
# applies concatenate_topics to already-processed frames.
df_train, df_val = map(data_handler.concatenate_topics, (df_train, df_val))

In [None]:
# Only the first 100 training rows are tokenized here (presumably a quick-run
# subset — confirm before launching a full training run).
train_subset = df_train[:100]

# Tokenizer for the 'bert-base-uncased' checkpoint, with lower-casing enabled.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

tokenized = data_handler.tokenize_df(train_subset, bert_tokenizer)

In [None]:
# --- Hyperparameters --------------------------------------------------------
# Commonly suggested BERT fine-tuning ranges:
#   batch size:            16, 32
#   learning rate (Adam):  5e-5, 3e-5, 2e-5
#   number of epochs:      2, 3, 4 (the BERT authors recommend between 2 and 4;
#                          this notebook currently runs a single epoch)
batch_size = 32
learning_rate = 2e-5
adam_epsilon = 1e-8
epochs = 1

# --- Model, data and loss ---------------------------------------------------
model = SiameseNetwork(bert_type=BertModel.from_pretrained('bert-base-uncased'))

# Batches are served in dataset order (shuffling is switched off).
train_loader = DataLoader(tokenized, batch_size=batch_size, shuffle=False)

# ContrastiveLoss from pytorch_metric_learning; the project-local
# custom_losses.ContrastiveLoss is imported by this notebook but not used here.
train_loss = losses.ContrastiveLoss()

# --- Optimizer and learning-rate schedule -----------------------------------
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)

# Total number of training steps = (number of batches) x (number of epochs).
# This is not the same as the number of training samples.
total_steps = epochs * len(train_loader)

# Linear schedule without any warm-up steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# --- Training ---------------------------------------------------------------
# NOTE(review): `None` is passed as train()'s second positional argument, which is
# presumably the device — confirm against siamese_network.train, since DEVICE is
# defined in this notebook but never handed over.
# `encoding` is rebound on every epoch, so it ends up holding the last epoch's
# return value.
for epoch in range(1, epochs + 1):
    encoding = train(model, None, train_loader, train_loss, optimizer, epoch, scheduler)
    # Evaluation is currently disabled: test(model, device, test_loader)
