diff --git a/.gitignore b/.gitignore index 75352a8..5489622 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,7 @@ dmypy.json # Pyre type checker .pyre/ + +.idea/ +runs/ +/model/ diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py new file mode 100644 index 0000000..7d6fdca --- /dev/null +++ b/alphafold2_pytorch/transformer.py @@ -0,0 +1,102 @@ +###################################################################### +# Transformer! +# ------------ +# +# Transformer is a Seq2Seq model introduced in `“Attention is all you +# need” `__ +# paper for solving machine translation tasks. The Transformer model consists +# of an encoder and a decoder block, each containing a fixed number of layers. +# +# The encoder processes the input sequence by propagating it through a series +# of Multi-head Attention and Feed forward network layers. The output from +# the encoder, referred to as ``memory``, is fed to the decoder along with +# target tensors. Encoder and decoder are trained in an end-to-end fashion +# using the teacher forcing technique. 
import math

import torch
from torch import nn
from torch import Tensor
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)


class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder transformer producing ``tgt_vocab_size`` values per position.

    Source and decoder-input tokens are embedded, combined with sinusoidal
    positional encodings, run through ``TransformerEncoder`` /
    ``TransformerDecoder`` stacks and projected by the linear ``generator``.

    ``PositionalEncoding`` indexes positions along dim 0, so token tensors are
    expected position-first, i.e. ``(seq_len, batch)``.
    NOTE(review): ``train_simple.py`` unpacks ``b, l = seq.shape`` and passes
    that tensor straight in — confirm the intended axis order with the caller.

    Args:
        num_encoder_layers: number of stacked encoder layers.
        num_decoder_layers: number of stacked decoder layers.
        emb_size: embedding / model width (``d_model``).
        src_vocab_size: number of distinct input token ids.
        tgt_vocab_size: width of the final linear projection.
        dim_feedforward: hidden width of each feed-forward sub-layer.
        num_head: number of attention heads.
        dropout: dropout probability used in the layers and positional encoding.
        activation: feed-forward activation name passed to the torch layers.
        max_len: number of positions pre-computed by the positional encoding.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.0, activation: str = "relu",
                 max_len: int = 5000):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=num_head,
                                                dim_feedforward=dim_feedforward, dropout=dropout,
                                                activation=activation)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=num_head,
                                                dim_feedforward=dim_feedforward, dropout=dropout,
                                                activation=activation)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        # The decoder-input embedding is deliberately left sized by
        # ``src_vocab_size``: train_simple.py feeds the source tokens as the
        # decoder input while ``tgt_vocab_size`` is only the output width (3).
        self.tgt_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout, maxlen=max_len)

    # TODO: ``src_mask`` / ``tgt_mask`` are still accepted but unused; the
    # current caller passes a per-token boolean mask there, which is not an
    # attention mask, so they cannot simply be forwarded.
    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor,
                use_padding_mask: bool = True):
        """Run encoder and decoder and project the decoder output.

        Args:
            src: source token ids.
            trg: decoder-input token ids.
            src_mask: unused (see TODO above).
            tgt_mask: unused (see TODO above).
            src_padding_mask: key-padding mask for encoder self-attention.
            tgt_padding_mask: key-padding mask for decoder self-attention.
            memory_key_padding_mask: key-padding mask for the encoder output
                used in decoder cross-attention.
            use_padding_mask: when ``False`` all padding masks are ignored.

        Returns:
            Tensor shaped like the decoder output with last dim ``tgt_vocab_size``.
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        if use_padding_mask:
            memory = self.transformer_encoder(src_emb, src_key_padding_mask=src_padding_mask)
            # Forward the memory mask too, so cross-attention does not read
            # encoder outputs at padded source positions.
            outs = self.transformer_decoder(tgt_emb, memory,
                                            tgt_key_padding_mask=tgt_padding_mask,
                                            memory_key_padding_mask=memory_key_padding_mask)
        else:
            memory = self.transformer_encoder(src_emb)
            outs = self.transformer_decoder(tgt_emb, memory)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Return the encoder output ("memory") for ``src``.

        Applies the same embedding + positional encoding as ``forward`` and
        passes ``src_mask`` as the encoder attention mask.
        """
        return self.transformer_encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Return the raw decoder output (before ``generator``) for ``tgt`` given ``memory``."""
        return self.transformer_decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)


######################################################################
# Text tokens are represented by using token embeddings. Positional
# encoding is added to the token embedding to introduce a notion of word
# order.
#

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to embeddings, then apply dropout.

    The table is stored as a non-trainable buffer of shape
    ``(maxlen, 1, emb_size)``; even columns hold sines, odd columns cosines.
    """

    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super().__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For an odd emb_size there is one fewer odd column than even columns.
        pos_embedding[:, 1::2] = torch.cos(angles)[:, :emb_size // 2]
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Add the first ``token_embedding.size(0)`` position rows and apply dropout."""
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    """Embedding lookup scaled by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Return scaled embeddings; ``tokens`` is cast to ``long`` first."""
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
from torch.utils.data import DataLoader import torch.nn.functional as F from einops import rearrange @@ -7,7 +8,7 @@ # data import sidechainnet as scn -from sidechainnet.sequence.utils import VOCAB +# from sidechainnet.sequence.utils import VOCAB from sidechainnet.structure.build_info import NUM_COORDS_PER_RES # models @@ -108,11 +109,11 @@ def cycle(loader, cond = lambda x: True): # mask the atoms and backbone positions for each residue # sequence embedding (msa / esm / attn / or nothing) - msa, embedds = None + msa, embedds = None, None # get embedds if FEATURES == "esm": - embedds = get_esm_embedd(seq, embedd_model, batch_converter) + embedds = get_esm_embedd(seq, embedd_model, batch_converter, device=DEVICE) # get msa here elif FEATURES == "msa": pass diff --git a/train_simple.py b/train_simple.py new file mode 100644 index 0000000..1c88687 --- /dev/null +++ b/train_simple.py @@ -0,0 +1,480 @@ +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from einops import rearrange + +import sidechainnet as scn +from sidechainnet.dataloaders.collate import prepare_dataloaders +from alphafold2_pytorch import Alphafold2 +import alphafold2_pytorch.constants as constants +from alphafold2_pytorch.utils import get_bucketed_distance_matrix +from alphafold2_pytorch.transformer import Seq2SeqTransformer +import time +import os +import matplotlib.pyplot as plt + +# constants + +DEVICE = None # defaults to cuda if available, else cpu +NUM_EPOCHS = int(3e5) +NUM_BATCHES = int(1e5) +GRADIENT_ACCUMULATE_EVERY = 16 +LEARNING_RATE = 1e-6 +IGNORE_INDEX = 20 +THRESHOLD_LENGTH = 100 +BATCH_SIZE = 100 + +# transformer constants + +SRC_VOCAB_SIZE = 21 # number of amino acids + padding 20 +TGT_VOCAB_SIZE = 3 # backbone torsion angle +NUM_ENCODER_LAYERS = 6 +NUM_DECODER_LAYERS = 6 +EMB_SIZE = 512 +NUM_HEAD = 8 +FFN_HID_DIM = 1024 +LOSS_WITHOUT_PADDING = 
def cycle(loader, cond=lambda x: True):
    """Yield items from ``loader`` endlessly, skipping items rejected by ``cond``.

    ``loader`` is iterated again from the start each time it is exhausted, so
    it must be re-iterable. If no item ever satisfies ``cond`` (or ``loader``
    is empty) the generator loops without yielding.
    """
    while True:
        for data in loader:
            if cond(data):
                yield data


def filter_dictionary_by_seq_length(raw_data, seq_length_threshold, portion):
    """Drop entries of one SidechainNet split whose sequence is too long.

    Args:
        raw_data (dict): SidechainNet dictionary; ``raw_data[portion]`` must
            hold parallel lists under the keys ``seq``, ``ang``, ``crd``,
            ``msk``, ``evo``, ``ids``, ``res`` and ``sec``.
        seq_length_threshold (int): entries with ``len(seq)`` greater than
            this value are removed.
        portion (str): key of the split to filter (e.g. ``"train"``).

    Returns:
        The same ``raw_data`` object, with ``raw_data[portion]`` replaced by a
        new dict containing only the eight keys above. Each kept ``ang`` entry
        is cut down to its first three columns (``ang[:, 0:3]``).
    """
    split = raw_data[portion]
    kept = {key: [] for key in ("seq", "ang", "ids", "evo", "msk", "crd", "sec", "res")}
    total_entries = 0
    n_filtered_entries = 0
    for seq, ang, crd, msk, evo, _id, res, sec in zip(split['seq'], split['ang'],
                                                      split['crd'], split['msk'],
                                                      split['evo'], split['ids'],
                                                      split['res'], split['sec']):
        total_entries += 1
        if len(seq) > seq_length_threshold:
            n_filtered_entries += 1
            continue
        kept["seq"].append(seq)
        kept["ang"].append(ang[:, 0:3])
        kept["ids"].append(_id)
        kept["evo"].append(evo)
        kept["msk"].append(msk)
        kept["crd"].append(crd)
        kept["sec"].append(sec)
        kept["res"].append(res)
    if n_filtered_entries:
        n_kept = total_entries - n_filtered_entries
        print(
            f"{portion}: {n_kept:.0f} out of {total_entries:.0f} ({n_kept / total_entries:.1%})"
            f" training set entries were included if sequence length <= {seq_length_threshold}")
    raw_data[portion] = kept
    return raw_data


def create_mask(src, tgt, ignore_index=None):
    """Build transposed padding masks for ``src`` and ``tgt``.

    Args:
        src: 2-D token tensor.
        tgt: 2-D token tensor.
        ignore_index: token id that marks padding; defaults to the module
            constant ``IGNORE_INDEX``.

    Returns:
        ``(src_padding_mask, tgt_padding_mask)``: boolean tensors that are
        ``True`` where the token equals ``ignore_index``, with dims 0 and 1
        swapped relative to the inputs.
    """
    if ignore_index is None:
        ignore_index = IGNORE_INDEX
    src_padding_mask = (src == ignore_index).transpose(0, 1)
    tgt_padding_mask = (tgt == ignore_index).transpose(0, 1)
    return src_padding_mask, tgt_padding_mask


def _wrap_angles(angs):
    """Add 2*pi to every angle below -3*pi/4; other angles are unchanged."""
    half_pi = torch.acos(torch.zeros(1)).item()
    return half_pi * 4 * (angs < -half_pi * 1.5) + angs


def _masked_loss_and_rows(logits, angs, mask, l):
    """Compute the loss plus the masked, row-flattened targets and predictions.

    Returns ``(loss, diff, angs_rows, logit_rows)`` where the ``*_rows``
    tensors have one row per (sequence, position) pair and are zero wherever
    ``mask`` is ``False``.
    """
    mask1 = mask.unsqueeze(2).expand(-1, -1, logits.shape[-1])
    wrapped = _wrap_angles(angs)
    angs_rows = (mask1 * wrapped).reshape(-1, wrapped.shape[-1])
    logit_rows = (mask1 * logits).reshape(-1, logits.shape[-1])
    if LOSS_WITHOUT_PADDING:
        flat_logits = logits[:, :l, :].reshape(-1, logits.shape[-1])
        flat_angs = angs.reshape(-1, angs.shape[-1])
        loss_ = loss_fn(flat_logits, flat_angs)
        diff = flat_logits - flat_angs
    else:
        loss_ = loss_fn(torch.masked_select(logits, mask1), torch.masked_select(wrapped, mask1))
        diff = logit_rows - angs_rows
    return loss_, diff, angs_rows, logit_rows


def _accumulate_abs_degrees(total, radians):
    """Add ``|degrees(radians)|`` (flattened) into the flat tensor ``total`` in place.

    The values are detached so the running statistics never hold on to the
    autograd graph. A shorter input only updates the leading elements (as if
    zero-padded); a longer input is truncated to ``total``'s length.
    """
    flat = torch.rad2deg(radians.detach()).reshape(-1).abs()
    n = min(total.numel(), flat.numel())
    total[:n] += flat[:n]


def _plot_angle_traces(angs_rows, logit_rows, start, length, prefix, suffix):
    """Save one target-vs-prediction figure per angle column (phi, psi, omega).

    Rows ``start:start + length`` are plotted; each figure is written to
    ``f"{prefix}{name}{suffix}"``.
    """
    for col, name in enumerate(("phi", "psi", "omega")):
        plt.clf()
        plt.plot(angs_rows[start:start + length, col].tolist(), label=name)
        plt.plot(logit_rows[start:start + length, col].tolist(), label=f"{name}_logit")
        plt.ylabel('angles')
        plt.legend()
        plt.savefig(f"{prefix}{name}{suffix}")


def _plot_epoch_summary(radian_diffs, logits_avg, angs_avg, path):
    """Save the per-epoch figure of averaged |diff|, |logit| and |angle| curves."""
    rows = THRESHOLD_LENGTH * TGT_VOCAB_SIZE
    plt.clf()
    for series, label in ((radian_diffs, 'diff'), (logits_avg, 'logit'), (angs_avg, 'ang')):
        plt.plot(torch.mean(series.reshape(rows, -1), 1).tolist(), label=label)
    plt.ylabel('angles')
    plt.legend()
    plt.savefig(path)


def train_epoch(model, train_iter, optimizer_, epoch):
    """Train ``model`` for one pass over ``train_iter`` and return the mean batch loss.

    Args:
        model: called as ``model(seq, seq, src_mask=..., tgt_mask=..., ...)``.
        train_iter: sized iterable of batches exposing ``int_seqs``, ``angs``
            and ``msks``.
        optimizer_: optimizer stepped once per batch.
        epoch: epoch number; figures are written when it is a multiple of
            ``graph_interval``.

    Returns:
        Sum of the batch losses divided by ``len(train_iter)``.
    """
    model.train()
    n_stats = THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE
    radian_diffs = torch.zeros(n_stats, device=DEVICE)
    logits_avg = torch.zeros(n_stats, device=DEVICE)
    angs_avg = torch.zeros(n_stats, device=DEVICE)
    should_plot = epoch % graph_interval == 0
    losses = 0
    for idx, batch in enumerate(train_iter):
        seq = batch.int_seqs.to(DEVICE)
        angs = batch.angs.to(DEVICE)
        mask = batch.msks.to(DEVICE).bool()
        b, l = seq.shape

        src_padding_mask, tgt_padding_mask = create_mask(seq, seq)

        # predict (the source sequence doubles as the decoder input)
        logits = model(seq, seq, src_mask=mask,
                       tgt_mask=mask, src_padding_mask=src_padding_mask,
                       tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask)

        optimizer_.zero_grad()

        loss_, diff, angs_rows, logit_rows = _masked_loss_and_rows(logits, angs, mask, l)

        # Running statistics; batches smaller than BATCH_SIZE simply fill
        # fewer slots instead of raising a shape error.
        _accumulate_abs_degrees(radian_diffs, diff)
        _accumulate_abs_degrees(logits_avg, logit_rows)
        _accumulate_abs_degrees(angs_avg, angs_rows)

        if idx == 0 and should_plot:
            # Rows are laid out sequence after sequence with stride ``l``.
            stride = angs.shape[1]
            pick = int(torch.randint(0, b, (1,)))
            _plot_angle_traces(angs_rows.detach(), logit_rows.detach(), pick * stride, stride,
                               f"graph/{MODEL_NAME}/train1_{epoch}_", ".png")

        loss_.backward()
        optimizer_.step()
        losses += loss_.item()

    n_batches = len(train_iter)
    if should_plot:
        _plot_epoch_summary(radian_diffs / n_batches, logits_avg / n_batches, angs_avg / n_batches,
                            f"graph/{MODEL_NAME}/train_{epoch}.png")
    return losses / n_batches


def evaluate(model, val_iter, split_, epoch=None):
    """Evaluate ``model`` on ``val_iter`` without gradients and return the mean batch loss.

    Args:
        model: called the same way as in ``train_epoch``.
        val_iter: sized iterable of batches exposing ``int_seqs``, ``angs``
            and ``msks``.
        split_: name of the split, used only in figure file names.
        epoch: epoch number controlling figure output. When omitted, the
            module-level ``epoch`` variable is used if it exists; if it does
            not, no figures are written.

    Returns:
        Sum of the batch losses divided by ``len(val_iter)``.
    """
    if epoch is None:
        epoch = globals().get("epoch")
    should_plot = epoch is not None and epoch % graph_interval == 0

    model.eval()
    losses = 0
    radian_diffs = None
    logits_avg = None
    angs_avg = None
    with torch.no_grad():
        for idx, batch in enumerate(val_iter):
            seq = batch.int_seqs.to(DEVICE)
            angs = batch.angs.to(DEVICE)
            mask = batch.msks.to(DEVICE).bool()
            b, l = seq.shape
            if radian_diffs is None:
                # Accumulators are sized from the first batch.
                n_stats = THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b
                radian_diffs = torch.zeros(n_stats, device=DEVICE)
                logits_avg = torch.zeros(n_stats, device=DEVICE)
                angs_avg = torch.zeros(n_stats, device=DEVICE)

            # Pad every sequence out to THRESHOLD_LENGTH positions.
            seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX)
            if not LOSS_WITHOUT_PADDING:
                angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0)
            mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False)

            src_padding_mask, tgt_padding_mask = create_mask(seq, seq)

            logits = model(seq, seq, src_mask=mask,
                           tgt_mask=mask, src_padding_mask=src_padding_mask,
                           tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask)

            loss_, diff, angs_rows, logit_rows = _masked_loss_and_rows(logits, angs, mask, l)

            _accumulate_abs_degrees(radian_diffs, diff)
            _accumulate_abs_degrees(logits_avg, logit_rows)
            _accumulate_abs_degrees(angs_avg, angs_rows)

            if should_plot:
                pick = int(torch.randint(0, b, (1,)))
                _plot_angle_traces(angs_rows, logit_rows, pick * THRESHOLD_LENGTH, THRESHOLD_LENGTH,
                                   f"graph/{MODEL_NAME}/valid1_{epoch}_", f"_{split_}_{idx}.png")

            losses += loss_.item()

    n_batches = len(val_iter)
    if should_plot and radian_diffs is not None:
        _plot_epoch_summary(radian_diffs / n_batches, logits_avg / n_batches, angs_avg / n_batches,
                            f"graph/{MODEL_NAME}/valid_{epoch}.png")
    return losses / n_batches
def restore_model(model_path, model, optimizer_, restore_optim=False, restore=True, scheduler_=None):
    """Read a checkpoint and optionally load it into ``model`` / ``optimizer_``.

    Args:
        model_path: checkpoint file written by ``torch.save`` with the keys
            ``epoch``, ``loss``, ``model_state_dict`` and optionally
            ``valid_loss``, ``optimizer_state_dict``, ``scheduler_state_dict``.
        model: module that receives ``model_state_dict`` when ``restore`` is true.
        optimizer_: optimizer that receives ``optimizer_state_dict`` when both
            ``restore`` and ``restore_optim`` are true.
        restore_optim: also restore optimizer (and scheduler) state.
        restore: when ``False`` only the checkpoint metadata is read; no state
            is loaded into ``model`` or ``optimizer_``.
        scheduler_: scheduler that receives ``scheduler_state_dict``; defaults
            to the module-level ``scheduler`` if one exists.

    Returns:
        ``(epoch, loss, valid_loss)`` from the checkpoint, or
        ``(0, 1e10, 1e10)`` when ``model_path`` does not exist.
    """
    prev_epoch_ = 0
    loss_ = 1e10
    valid_loss_ = 1e10
    if os.path.exists(model_path):
        # Load onto the CPU so a checkpoint saved on a GPU can be read on any
        # host; ``load_state_dict`` copies values onto the parameters' device.
        checkpoint = torch.load(model_path, map_location="cpu")
        if restore:
            model.load_state_dict(checkpoint['model_state_dict'])
            if restore_optim:
                optimizer_.load_state_dict(checkpoint['optimizer_state_dict'])
                sched = scheduler_ if scheduler_ is not None else globals().get('scheduler')
                if sched is not None and 'scheduler_state_dict' in checkpoint:
                    sched.load_state_dict(checkpoint['scheduler_state_dict'])
        prev_epoch_ = checkpoint['epoch']
        loss_ = checkpoint['loss']
        if 'valid_loss' in checkpoint:
            valid_loss_ = checkpoint['valid_loss']
    if restore:
        print(f"restore checkpoint. Epoch: {prev_epoch_}, loss: {loss_:.3f}, valid_loss: {valid_loss_:.3f}")
    else:
        print(f"best checkpoint. Epoch: {prev_epoch_}, loss: {loss_:.3f}, valid_loss: {valid_loss_:.3f}")
    return prev_epoch_, loss_, valid_loss_
valid_count, epoch) + writer_valid.flush() + # writer_train_eval.add_scalar("loss", train_eval_loss, epoch) + # writer_train_eval.flush() + scheduler.step(val_loss_sum / valid_count) + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': train_loss, + 'valid_loss': val_loss_sum / valid_count, + }, MODEL_PATH) + if val_loss_sum / valid_count < best_valid: + best_valid = val_loss_sum / valid_count + save_path = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_{epoch}_{best_valid:.3f}.pt" + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': train_loss, + 'valid_loss': best_valid, + }, BEST_MODEL_PATH) + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': train_loss, + 'valid_loss': best_valid, + }, save_path) + print(f"new best checkpoint. Epoch: {epoch}, loss: {train_loss:.3f}, valid_loss: {best_valid:.3f}") + writer_best.add_scalar("loss", best_valid, epoch) + writer_best.flush() + elif epoch > warmup_steps: + not_improved_count += 1 +print('train ended') +writer_train.close() +writer_valid.close() +# valid_count = 0 +# for split in scn.utils.download.VALID_SPLITS: +# writer_valids[valid_count].close() +# valid_count += 1