In [1]:
import os
import argparse
import copy
import json
import pickle
import Bio
from Bio.Align import substitution_matrices
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from abnumber import Chain
from evaluation.datasets import SAbDabDataset
from evaluation.datasets import get_dataset
from evaluation.utils.protein.writers import save_pdb
from evaluation.utils.data import *
from evaluation.utils.misc import *
from evaluation.utils.transforms import *
from xlm.utils import AttrDict
from xlm.utils import bool_flag, initialize_exp
from xlm.data.dictionary import Dictionary
from xlm.model.transformer import TransformerModel
from xlm.utils import to_cuda
from xlm.model.transformer import get_masks
from xlm.evaluation.evaluator import convert_to_text, calculate_identity
from transformers import BertModel
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

import pyrosetta
pyrosetta.init(silent=True)

from pyrosetta import pose_from_pdb, init
# from pyrosetta.rosetta import *
# from pyrosetta.teaching import *

#Core Includes
from rosetta.core.select import residue_selector as selections

from rosetta.protocols import antibody
init('-use_input_sc -ignore_unrecognized_res -check_cdr_chainbreaks false \
     -ignore_zero_occupancy false -load_PDB_components false -no_fconfig', silent=True)


def get_parser():
    """
    Generate a parameters parser.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Translate sentences")

    # main parameters
    parser.add_argument("--dump_path", type=str, default="dumped/", help="Experiment dump path")
    parser.add_argument("--exp_name", type=str, default="kir", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="123", help="Experiment ID")
    parser.add_argument("--beam_size", type=int, default=100)
    parser.add_argument("--excess_res", type=int, default=50)
    parser.add_argument("--reporter", type=bool, default=False)
    # model / output paths
    parser.add_argument("--model_path", type=str, default="/scratch/ssd004/scratch/benjami/AntiXLM/evaluation/checkpoint_rand.pkl", help="Model path")
    parser.add_argument("--output_path", type=str, default="evaluation/", help="Output path")

    # source language / target language
    parser.add_argument("--src_lang", type=str, default="ag", help="Source language")
    parser.add_argument("--tgt_lang", type=str, default="ab", help="Target language")

    parser.add_argument('-i', '--index', type=int, default=0)
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-c', '--config', type=str, default='evaluation/configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='evaluation/results')

    return parser

class SabdabEntry:
    def __init__(self, dataset, index, params, renumber=None) -> None:
        self.structure = dataset[index]
        self.entry = self.find_entry(dataset, index)
        
        self.structure_id = self.structure['id']
        self.ag_name = self.entry['ag_name']
        self.ag_chain = self.entry['ag_chains'][0]
        self.ab_chain = self.entry['H_chain']
        self.pdb_code = self.entry['pdbcode']

        self.f1 = self.structure['heavy']['FW1_seq']
        self.f2 = self.structure['heavy']['FW2_seq']
        self.f3 = self.structure['heavy']['FW3_seq']
        self.f4 = self.structure['heavy']['FW4_seq']

        self.c1 = self.structure['heavy']['H1_seq']
        self.c2 = self.structure['heavy']['H2_seq']
        self.c3 = self.structure['heavy']['H3_seq']
        
        self.excess = params.excess_res

        self.ab_seq = self.structure['heavy'].seq
        self.ag_seq = self.structure['antigen'].seq
        self.weights = {}
        self.set_weights()
        data_native = MergeChains()(self.structure)
        self.log_dir = get_new_log_dir(os.path.join(params.log_dir), prefix='%02d_%s' % (index, self.structure_id))
        save_pdb(data_native, os.path.join(self.log_dir, 'reference.pdb'))
        save_pdb(data_native, os.path.join(self.log_dir, 'reference_renamed.pdb'), 
                 rename={self.ag_chain:'A', self.ab_chain:'H'})
        pose = pose_from_pdb(os.path.join(self.log_dir, 'reference_renamed.pdb'))
        ab_info = antibody.AntibodyInfo(pose, antibody.Chothia_Scheme, antibody.North)

        for s in range(5,25):
            self.epi_residues = np.array(antibody.select_epitope_residues(ab_info, pose, s))[len(self.ab_seq):]
            if self.epi_residues.sum() > 20:
                break
        self.epi_range = (np.argmax(self.epi_residues), self.epi_residues.shape[0] - np.argmax(self.epi_residues[::-1]) - 1)
        self.epi_resseq = self.structure['antigen']['resseq'][self.epi_residues]

        save_pdb(data_native, os.path.join(self.log_dir, 'antigen.pdb'), ignore_chain=self.ab_chain)
        save_pdb(data_native, os.path.join(self.log_dir, 'antibody.pdb'), ignore_chain=self.ag_chain)        
        save_pdb(data_native, os.path.join(self.log_dir, 'cutted_antigen.pdb'), ignore_chain=self.ab_chain,
                 write_range={self.ag_chain: (max(0, self.epi_range[0]-self.excess), self.epi_range[1]+self.excess)})        
        save_pdb(data_native, os.path.join(self.log_dir, 'cutted_refrence_renamed.pdb'),
                 write_range={'A': (max(0, self.epi_range[0]-self.excess), self.epi_range[1]+self.excess)},
                 rename={self.ag_chain:'A', self.ab_chain:'H'})

        self.identity = {}
        self.generated_sequences = {}

    @property
    def antigen(self):
        extra = int(min(self.excess, (240 - (self.epi_range[1] - self.epi_range[0]))/2))
        return self.ag_seq[max(0, self.epi_range[0]-extra): self.epi_range[1]+extra]

    @property
    def antibody(self):
        return self.ab_seq

    def set_weights(self):
        self.weights['CDR1'] = self._construct_weight(cdr1=True)
        self.weights['CDR2'] = self._construct_weight(cdr2=True)
        self.weights['CDR3'] = self._construct_weight(cdr3=True)
        self.weights['CDR123'] = self._construct_weight(cdr1=True, cdr2=True, cdr3=True)

    def _construct_weight(self, cdr1=False, cdr2=False, cdr3=False):
        return [0] * len(self.f1) + \
            ([1] * len(self.c1) if cdr1 else [0] * len(self.c1)) + \
            [0] * len(self.f2) + \
            ([1] * len(self.c2) if cdr2 else [0] * len(self.c2)) + \
            [0] * len(self.f3) + \
            ([1] * len(self.c3) if cdr3 else [0] * len(self.c3)) + \
            [0] * len(self.f4)


    def find_entry(self, dataset:SAbDabDataset, index):
        for entry in dataset.sabdab_entries:
            if entry['id'] == self.structure['id']:
                return entry

    def write_generated(self):
        for key in self.generated_sequences:
        # Create a file name for the fasta file
            with open(os.path.join(self.log_dir, key)+'.fasta', "w") as f:
                for i, seq in enumerate(self.generated_sequences[key]):
                    # Open the file for writing
                    f.write(">{0}_sequence".format(key) + str(i) + "\n")
                    # Write the sequence to the file in fasta format
                    f.write(seq.replace(' ', '') + "\n")

    

def get_model(params):    
    # initialize the experiment
    # logger = initialize_exp(params)

    # generate parser / parse parameters
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    # logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=False).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict({k[7:]:reloaded['encoder'][k] for k in reloaded['encoder']})
    decoder.load_state_dict({k[7:]:reloaded['decoder'][k] for k in reloaded['decoder']})
    params.src_id = model_params.lang2id['ag']
    params.tgt_id = model_params.lang2id['ab']
    encoder.eval()
    decoder.eval()
    bert = BertModel.from_pretrained("Rostlab/prot_bert")
    bert.eval()
    bert = bert.cuda()
    return encoder, decoder, dico, bert


def get_sabdab(params):
    # Load configs
    config, config_name = load_config(params.config)
    # Testset
    dataset = get_dataset(config.dataset.test)
    # Logging

    return dataset


def build_batch(seq, lang, eos, bos, pad):
    lengths = torch.LongTensor([len(seq) + 2])
    batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(pad)
    batch[0] = bos
    batch[1:lengths[0] - 1, 0].copy_(seq)
    batch[lengths[0] - 1, 0] = eos
    langs = batch.clone().fill_(lang)
    return batch, lengths, langs


def write_generated(log_dir, eval_step, sequences):
    log = get_new_log_dir(os.path.join(log_dir, eval_step))
    for i, seq in enumerate(sequences):
    # Create a file name for the fasta file
        filename = "sequence" + str(i) + ".fasta"

        # Open the file for writing
        with open(os.path.join(log, filename), "w") as f:
            # Write the sequence to the file in fasta format
            f.write(">sequence" + str(i) + "\n")
            f.write(seq.replace(' ', '') + "\n")

def evaluate(encoder, decoder, dico, bert, params, aligner, sample:SabdabEntry, eval_modes=['CDR1', 'CDR2', 'CDR3', 'CDR123', 'GEN']):
    
    ag_tensor, ag_length_tensor, ag_langs_tensor = build_batch(torch.LongTensor([dico.index(w) for w in sample.antigen]), 
                                                               params.tgt_id, params.eos_index, params.bos_index, params.pad_index)
    ab_tensor, ab_length_tensor, ab_langs_tensor = build_batch(torch.LongTensor([dico.index(w) for w in sample.antibody]), 
                                                               params.tgt_id, params.eos_index, params.bos_index, params.pad_index)
    token_type_ids = torch.zeros_like(ag_tensor)
            
    ab_tensor, ab_length_tensor, ab_langs_tensor, ag_tensor, ag_length_tensor, ag_langs_tensor, token_type_ids = \
        to_cuda(ab_tensor, ab_length_tensor, ab_langs_tensor, ag_tensor, ag_length_tensor, ag_langs_tensor, token_type_ids)

    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):  
        with torch.no_grad(): 
            bert_embed = bert(input_ids=ag_tensor.T, token_type_ids=token_type_ids.T, attention_mask=get_masks(ag_tensor.size()[0], ag_length_tensor, False)[0]).last_hidden_state

            enc1 = encoder('fwd', x=ag_tensor, lengths=ag_length_tensor, langs=ag_langs_tensor, causal=False, bert_embed=bert_embed)
            enc1 = enc1.transpose(0, 1)
    beam_size = params.beam_size
    for eval_step in eval_modes:
        if eval_step == 'GEN':
            beam_size = 20
        ab_weights_tensor, _, _ = build_batch(torch.LongTensor(sample.weights.get(eval_step, sample.weights['CDR123'])), params.tgt_id, 0, 0, 0)
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):   
            with torch.no_grad(): 
                generated, lengths = decoder.generate_beam(
                                    enc1, ag_length_tensor, params.tgt_id, beam_size=beam_size,
                                    length_penalty=1,
                                    early_stopping=False,
                                    max_len=160,
                                    bert_embed=bert_embed,
                                    cdr_generation=eval_step != 'GEN',
                                    tgt_frw=ab_tensor,
                                    w=ab_weights_tensor,
                                    open_end=False
                                )
        hypothesis_text = convert_to_text(generated, lengths, dico, params)
        batch_generate_identity, batch_generate_cdr_identity = calculate_identity(aligner, [sample.antibody], hypothesis_text, 'ab', ab_weights_tensor, beam_size=beam_size)
        write_generated(sample.log_dir, eval_step, hypothesis_text)
        sample.generated_sequences[eval_step] = hypothesis_text
        
        if eval_step == 'GEN':
            sample.identity[eval_step+'_CDR'] = batch_generate_cdr_identity
            sample.identity[eval_step+'_ALL'] = batch_generate_identity
        else:
            sample.identity[eval_step] = batch_generate_cdr_identity

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


core.init: Checking for fconfig files in pwd and ./rosetta/flags
core.init: Rosetta version: PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python39.Release r337 2022.49+release.201d763 201d7639f91f369d58b1adf514f3febaf6154c58 http://www.pyrosetta.org 2022-12-07T16:15:33
core.init: command: PyRosetta -ex1 -ex2aro -database /h/benjami/.conda/envs/AntiXLM/lib/python3.9/site-packages/pyrosetta/database
basic.random.init_random_generator: 'RNG device' seed mode, using '/dev/urandom', seed=964659220 seed_offset=0 real_seed=964659220 thread_index=0
basic.random.init_random_generator: RandomGenerator:init: Normal mode, seed=964659220 RG_type=mt19937
core.init: Rosetta version: PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python39.Release r337 2022.49+release.201d763 201d7639f91f369d58b1adf514f3febaf6154c58 http://www.pyrosetta.org 2022-12-07T16:15:33
core.init: command: PyRosetta -use_input_sc -ignore_unrecognized_res -check_cdr_chainbreaks false -ignore_zero_occupancy fal

  from rosetta.core.select import residue_selector as selections


In [2]:
parser = get_parser()
params = parser.parse_args([])
log = get_new_log_dir(os.path.join(params.out_root), date=True)
params.log_dir = log
dataset = get_sabdab(params)
test_samples = []

encoder, decoder, dico, bert = get_model(params)
for param in bert.parameters():
    param.requires_grad = False
for param in encoder.parameters():
    param.requires_grad = False
for param in decoder.parameters():
    param.requires_grad = False
samples = dict()
identity = {'CDR1':0, 'CDR2':0, 'CDR3':0}
i = 0
sample = SabdabEntry(dataset=dataset, index=i, params=params)
 
 

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


core.chemical.GlobalResidueTypeSet: Finished initializing fa_standard residue type set.  Created 985 residue types
core.chemical.GlobalResidueTypeSet: Total time to initialize 3.14974 seconds.
core.import_pose.import_pose: File 'evaluation/results/2023_04_30__16_50_20/00_4ffv_H_L_A_/reference_renamed.pdb' automatically determined to be of type PDB
core.conformation.Conformation: Found disulfide between residues 288 299
core.conformation.Conformation: current variant for 288 CYS
core.conformation.Conformation: current variant for 299 CYS
core.conformation.Conformation: current variant for 288 CYD
core.conformation.Conformation: current variant for 299 CYD
core.conformation.Conformation: Found disulfide between residues 345 357
core.conformation.Conformation: current variant for 345 CYS
core.conformation.Conformation: current variant for 357 CYS
core.conformation.Conformation: current variant for 345 CYD
core.conformation.Conformation: current variant for 357 CYD
core.conformation.Confor

In [16]:
ag_tensor, ag_length_tensor, ag_langs_tensor = build_batch(torch.LongTensor([dico.index(w) for w in sample.antigen]), 
                                                           params.tgt_id, params.eos_index, params.bos_index, params.pad_index)
ab_tensor, ab_length_tensor, ab_langs_tensor = build_batch(torch.LongTensor([dico.index(w) for w in sample.antibody]), 
                                                           params.tgt_id, params.eos_index, params.bos_index, params.pad_index)
token_type_ids = torch.zeros_like(ag_tensor)

ab_tensor, ab_length_tensor, ab_langs_tensor, ag_tensor, ag_length_tensor, ag_langs_tensor, token_type_ids = \
    to_cuda(ab_tensor, ab_length_tensor, ab_langs_tensor, ag_tensor, ag_length_tensor, ag_langs_tensor, token_type_ids)

In [17]:
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):  
    bert_embed = bert(input_ids=ag_tensor.T, token_type_ids=token_type_ids.T, attention_mask=get_masks(ag_tensor.size()[0], ag_length_tensor, False)[0]).last_hidden_state

    enc1 = encoder('fwd', x=ag_tensor, lengths=ag_length_tensor, langs=ag_langs_tensor, causal=False, bert_embed=bert_embed)
    enc1 = enc1.transpose(0, 1)
ab_weights_tensor, _, _ = build_batch(torch.LongTensor(sample.weights.get('CDR3', sample.weights['CDR123'])), params.tgt_id, 0, 0, 0)
ab_weights_tensor,  = to_cuda(ab_weights_tensor)

In [341]:
class Optimize(torch.nn.Module):
    def __init__(self, encoder, decoder, bert, sample, params, dico):
        super(Optimize, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.bert_word_embedding = bert.embeddings.word_embeddings
        self.bert = bert
        ag_tensor, ag_length_tensor, ag_langs_tensor = build_batch(torch.LongTensor([dico.index(w) for w in sample.antigen]), 
                                                           params.tgt_id, params.eos_index, params.bos_index, params.pad_index)
        ab_tensor, ab_length_tensor, ab_langs_tensor = build_batch(torch.LongTensor([dico.index(w) for w in sample.antibody]), 
                                                           params.tgt_id, params.eos_index, params.bos_index, params.pad_index)
        ab_weights_tensor, _, _ = build_batch(torch.LongTensor(sample.weights.get('CDR3', sample.weights['CDR123'])), params.tgt_id, 0, 0, 0)
        self.ab_weights_tensor,  = to_cuda(ab_weights_tensor)

        self.lr = 100

        self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor, self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor= \
        to_cuda(ab_tensor, ab_length_tensor, ab_langs_tensor, ag_tensor, ag_length_tensor, ag_langs_tensor)
        
        
    def encode(self, input_ids, length_tensor, langs, one_hot=None):
        token_type_ids = torch.zeros_like(input_ids)
        token_type_ids = token_type_ids.cuda()
        if one_hot is None:
            one_hot = self.one_hot(input_ids)
        bert_embed_x = self.one_hot_to_embed(one_hot, self.bert_word_embedding)
        encoder_embed_x = self.one_hot_to_embed(one_hot, self.encoder.embeddings)

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):  
            bert_embed = self.bert(inputs_embeds=bert_embed_x, token_type_ids=token_type_ids.T, attention_mask=get_masks(input_ids.size()[0], length_tensor, False)[0]).last_hidden_state

            enc1 = self.encoder('fwd', x=input_ids, lengths=length_tensor, langs=langs, causal=False, bert_embed=bert_embed, embed_x=encoder_embed_x)
            enc1 = enc1.transpose(0, 1)
        return enc1, bert_embed
      
        
    def decode(self, input_ids, length_tensor, langs, enc, enc_len, bert_embed, one_hot=None):
        alen = torch.arange(length_tensor.max(), dtype=torch.long, device=input_ids.device)
        pred_mask = (alen[:, None] < length_tensor[None] - 1)[:-1]   # do not predict anything given the last target word
        pred_mask = pred_mask.expand([input_ids.shape[0]-1,1])
        y = input_ids[1:].masked_select(pred_mask)
        if one_hot is None:
            one_hot = self.one_hot(input_ids)
        embed_x = self.one_hot_to_embed(one_hot, self.decoder.embeddings)
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            dec = self.decoder('fwd', x=input_ids, lengths=length_tensor, causal=True, langs=langs,
                          src_enc=enc, src_len=enc_len, bert_embed=bert_embed,
                          embed_x=embed_x
                        )
            scores, loss = decoder('predict', tensor=dec, pred_mask=pred_mask, y=y, get_scores=True)
            scores = F.log_softmax(scores, dim=-1).exp()
            loss = -scores[torch.arange(input_ids.shape[0]-1),y].mean()


        return loss, scores

    def one_hot(self, tensor, noise=False):
        one_hot_tensor = F.one_hot(tensor, num_classes=30).type(torch.float)
        if noise:
            one_hot_tensor = self.normal_label_smoothing(one_hot_tensor, self.ab_weights_tensor)
        return one_hot_tensor
    
    def one_hot_to_embed(self, tensor, module):
        return tensor.matmul(module.weight).transpose(0,1)

    def normal_noise(self, one_hot, weight, epsilon=10, mean=0.0, std_dev=100):
        normal_distribution = torch.normal(mean, std_dev, one_hot.shape)
        normal_distribution = torch.abs(normal_distribution)
        normal_distribution = normal_distribution / normal_distribution.sum(2).unsqueeze(2).expand(normal_distribution.shape)  # Normalize the distribution to sum to 1
        new_tensor = epsilon * normal_distribution.cuda()
        result = one_hot.clone()

        result[(weight == 1).flatten()] = new_tensor[(weight == 1).flatten()]
        # result[(weight == 1).flatten(),0, :5] = 0
        return result
    
    def normal_label_smoothing(self, one_hot, weight, epsilon=0.2, mean=0.0, std_dev=1):
        normal_distribution = torch.normal(mean, std_dev, one_hot.shape)
        normal_distribution = torch.abs(normal_distribution)
        normal_distribution = normal_distribution / normal_distribution.sum(2).unsqueeze(2).expand(normal_distribution.shape)  # Normalize the distribution to sum to 1
        new_tensor = (1 - epsilon) * one_hot + epsilon * normal_distribution.cuda()
        result = one_hot.clone()

        result[(weight == 1).flatten()] = new_tensor[(weight == 1).flatten()]
        return result
        
    def one_hot_ab(self):
        return self.normal_label_smoothing(self.one_hot(self.ab_tensor), self.ab_weights_tensor).detach().type(torch.float).requires_grad_()

    def forward(self, ab_onehot):
        temp = ab_onehot.clone()
        temp[(self.ab_weights_tensor == 1).flatten(),0,:5] = 0

        enc, bert = self.encode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor)
        loss, scores = self.decode(temp.argmax(-1), self.ab_length_tensor, self.ab_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert, F.softmax(ab_onehot, -1))
        loss = torch.matmul(scores, F.softmax(ab_onehot, -1).squeeze(1).T+1e-5)[torch.arange(118), torch.arange(118)].mean()
        print(loss.item())

        Frobenius_loss = torch.norm(F.softmax(ab_onehot, -1) - self.one_hot(self.ab_tensor), dim=2).mean()
        
        print(Frobenius_loss)
        l1_norms = torch.norm(F.softmax(ab_onehot, -1), p=1, dim=-1)
        reg_term = torch.abs(l1_norms)
        print(reg_term.mean())
        loss = loss #+ 0.05*reg_term.mean()
        loss.backward()


        changed, new_ab_onehot = self.update(ab_onehot)
        print(changed)
        with torch.no_grad():
            print('new onehot ab-ag')
            enc, bert = self.encode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor, new_ab_onehot)
            loss, scores = self.decode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
            print(loss.item())

            print('new ag-ab')
            temp = new_ab_onehot.clone()
            temp[(self.ab_weights_tensor == 1).flatten(),0,:5] = 0

            enc, bert = self.encode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor)
            loss, scores = self.decode(temp.argmax(-1), self.ab_length_tensor, self.ab_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
            print(loss.item())

            print('original ab-ag')
            enc, bert = self.encode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor)
            loss, scores = self.decode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
            print(loss.item())


            print('original ag-ab')
            enc, bert = self.encode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor)
            loss, scores = self.decode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)

            print(loss.item())
            
        # new_ab_onehot[(self.ab_weights_tensor == 1).flatten()] = (new_ab_onehot/new_ab_onehot.sum(-1).unsqueeze(2))[(self.ab_weights_tensor == 1).flatten()]
        # new_ab_onehot[new_ab_onehot < 0] = 0
        return new_ab_onehot

    def update(self, ab_onehot):
        new_ab_onehot = ab_onehot - \
            ((ab_onehot.grad * self.lr) * self.ab_weights_tensor.unsqueeze(2))
        changed = not (self.ab_tensor == new_ab_onehot.argmax(-1)).all()
        # new_ab_onehot[(self.ab_weights_tensor == 1).flatten(),0,:5] = 0
        # self.smoothed_weight = self.sum_one(new_ab_onehot).detach().requires_grad_()
        return changed, new_ab_onehot




In [342]:
o = Optimize(encoder, decoder, bert, sample, params, dico)
o.ab_weights_tensor[100:] = 0
ab_onehot = o.normal_noise(o.one_hot_ab(), o.ab_weights_tensor).detach().type(torch.float)
ab_onehot[(o.ab_weights_tensor == 0).flatten()] = ab_onehot[(o.ab_weights_tensor == 0).flatten()] * 100
ab_onehot = ab_onehot.requires_grad_()
# o.ab_tensor[101] = 10
F.softmax(ab_onehot[99:107],-1).max(-1)

torch.return_types.max(
values=tensor([[0.0614],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[20],
        [13],
        [14],
        [ 8],
        [19],
        [19],
        [14],
        [ 8]], device='cuda:0'))

In [376]:
ab_onehot = o(ab_onehot.detach().type(torch.float).requires_grad_())

0.0775437131524086
tensor(0.0084, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1., device='cuda:0', grad_fn=<MeanBackward0>)
True
new onehot ab-ag
-0.05587519332766533
new ag-ab
-0.9163196682929993
original ab-ag
-0.05872029438614845
original ag-ab
-0.9474751353263855


In [377]:
F.softmax(ab_onehot[99:107],-1).max(-1)

torch.return_types.max(
values=tensor([[0.0657],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[18],
        [13],
        [14],
        [ 8],
        [19],
        [19],
        [14],
        [ 8]], device='cuda:0'))

In [371]:
print(ab_onehot[99])
F.softmax(ab_onehot,-1)[99,0,18:20]

tensor([[ 0.1769,  0.5026,  0.5507,  0.1713,  0.8702,  0.3840,  0.4714,  0.0934,
         -0.0408,  0.0294, -0.0264,  0.1005,  0.1775, -0.0373,  0.0919,  0.2429,
          0.3876, -0.0404,  0.9800,  0.0295,  0.8210,  0.3156,  0.2815,  0.4407,
         -0.0431,  0.5557,  0.9140,  0.3180,  0.9323,  0.3493]],
       device='cuda:0', grad_fn=<SelectBackward0>)


tensor([0.0604, 0.0234], device='cuda:0', grad_fn=<SliceBackward0>)

In [664]:
l1_norms = torch.norm(ab_onehot, p=1, dim=2)
reg_term = torch.sum(torch.abs(l1_norms - 1))
reg_term


tensor(128.8144, device='cuda:0', grad_fn=<SumBackward0>)

In [297]:
print(torch.logical_not((o.ab_tensor == ab_onehot.argmax(-1)).flatten()).nonzero(as_tuple=True))
print(ab_onehot[99])
print(ab_onehot[99].argmax())
print(o.ab_tensor[99])

(tensor([ 99, 101], device='cuda:0'),)
tensor([[-1.3193, -0.7160, -0.5079, -0.0474, -1.2667,  0.0269,  0.1399,  0.1745,
          6.9504,  3.3534,  1.0780,  2.7917,  0.1616,  0.3097, -0.0782, -0.1693,
          1.2231, -0.2720,  0.6096, -0.5265, -0.2374, -0.6325,  0.5630,  0.5040,
          0.7064,  0.1829, -1.3335, -0.5932,  0.1758, -1.2510]],
       device='cuda:0', grad_fn=<SelectBackward0>)
tensor(8, device='cuda:0')
tensor([19], device='cuda:0')


In [507]:
np.array([dico.id2word[s.item()] for s in o.ab_tensor])

array(['[CLS]', 'E', 'F', 'Q', 'L', 'Q', 'Q', 'S', 'G', 'P', 'E', 'L',
       'V', 'K', 'P', 'G', 'A', 'S', 'V', 'K', 'I', 'S', 'C', 'K', 'A',
       'S', 'G', 'Y', 'S', 'F', 'T', 'D', 'Y', 'N', 'I', 'N', 'W', 'M',
       'K', 'Q', 'S', 'N', 'G', 'K', 'S', 'L', 'E', 'W', 'I', 'G', 'V',
       'V', 'I', 'P', 'K', 'Y', 'G', 'T', 'T', 'N', 'Y', 'N', 'Q', 'K',
       'F', 'Q', 'G', 'K', 'A', 'T', 'L', 'T', 'V', 'D', 'Q', 'S', 'S',
       'S', 'T', 'A', 'Y', 'I', 'Q', 'L', 'N', 'S', 'L', 'T', 'S', 'E',
       'D', 'S', 'A', 'V', 'Y', 'Y', 'C', 'T', 'R', 'F', 'R', 'D', 'V',
       'F', 'F', 'D', 'V', 'W', 'G', 'T', 'G', 'T', 'T', 'V', 'T', 'V',
       'S', 'S', '[SEP]'], dtype='<U5')

In [508]:
np.array([dico.id2word[s.item()] for s in ab_onehot.argmax(-1)])

array(['[CLS]', 'E', 'F', 'Q', 'L', 'Q', 'Q', 'S', 'G', 'P', 'E', 'L',
       'V', 'K', 'P', 'G', 'A', 'S', 'V', 'K', 'I', 'S', 'C', 'K', 'A',
       'S', 'G', 'Y', 'S', 'F', 'T', 'D', 'Y', 'N', 'I', 'N', 'W', 'M',
       'K', 'Q', 'S', 'N', 'G', 'K', 'S', 'L', 'E', 'W', 'I', 'G', 'V',
       'V', 'I', 'P', 'K', 'Y', 'G', 'T', 'T', 'N', 'Y', 'N', 'Q', 'K',
       'F', 'Q', 'G', 'K', 'A', 'T', 'L', 'T', 'V', 'D', 'Q', 'S', 'S',
       'S', 'T', 'A', 'Y', 'I', 'Q', 'L', 'N', 'S', 'L', 'T', 'S', 'E',
       'D', 'S', 'A', 'V', 'Y', 'Y', 'C', 'T', 'R', 'R', 'R', 'D', 'V',
       'F', 'C', 'I', 'V', 'W', 'G', 'T', 'G', 'T', 'T', 'V', 'T', 'V',
       'S', 'S', '[SEP]'], dtype='<U5')

In [510]:
F.gumbel_softmax(ab_onehot[99])

tensor([[0.0040, 0.0041, 0.0021, 0.0028, 0.0009, 0.0054, 0.0195, 0.0010, 0.0026,
         0.0041, 0.0010, 0.0015, 0.0067, 0.0159, 0.0012, 0.6727, 0.0016, 0.0038,
         0.1036, 0.0022, 0.0019, 0.0025, 0.0032, 0.0357, 0.0167, 0.0491, 0.0096,
         0.0029, 0.0151, 0.0065]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [98]:

def update(input_tensor, ab_weights_tensor):
    new_tensor = input_tensor.clone().detach()
    print('******')
    step = 0.1
    while (F.one_hot(new_tensor.argmax(2), num_classes=30) == input_tensor).all() and \
        (new_tensor.argmax(2)[new_tensor.argmax(2).flatten() != input_tensor.argmax(2).flatten()] > 4).all() and \
            (new_tensor.argmax(2)[new_tensor.argmax(2).flatten() != input_tensor.argmax(2).flatten()] < 25).all():
        new_tensor = new_tensor + ((input_tensor.grad * step) * ab_weights_tensor.transpose(0,1).unsqueeze(2).expand(input_tensor.shape).cuda())
        step += 0.1
        print(1)
    # print((new_tensor.flatten() != input_tensor.flatten()).argmax())
    print(input_tensor.argmax(2)[new_tensor.argmax(2).flatten() != input_tensor.argmax(2).flatten()])

    print(new_tensor.argmax(2)[new_tensor.argmax(2).flatten() != input_tensor.argmax(2).flatten()])

    return new_tensor.argmax(2)
        

alen = torch.arange(ab_length_tensor.max(), dtype=torch.long, device=ab_length_tensor.device)
pred_mask = (alen[:, None] < ab_length_tensor[None] - 1)[:-1]   # do not predict anything given the last target word
y = ab_tensor[1:].masked_select(pred_mask)
proj.weight = torch.nn.Parameter(decoder.embeddings.weight[5:25, :])

input_tensor = ab_tensor
input_tensor = F.one_hot(input_tensor, num_classes=30).type(torch.float).detach().requires_grad_()
# embed_x = decoder.embeddings.weight[input_tensor,:].transpose(0,1).detach().requires_grad_()

for i in range(2):
    embed_x = input_tensor.matmul(decoder.embeddings.weight).transpose(0,1)

    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True): 
        dec = decoder('fwd', x=ab_tensor, lengths=ab_length_tensor, langs=ab_tensor.clone().fill_(params.tgt_id), causal=True,
                            src_enc=enc1, src_len=ag_length_tensor, bert_embed=bert_embed, embed_x=embed_x
                        )
        scores, loss = decoder('predict', tensor=dec, pred_mask=pred_mask, y=y, get_scores=True)
        s = F.log_softmax(scores, dim=-1)
        loss = s[torch.arange(s.shape[0]), input_tensor.argmax(2)[1:].flatten()].sum()
        print(loss)
        loss.backward()
        input_tensor = update(input_tensor, ab_weights_tensor.transpose(0,1))
        

NameError: name 'proj' is not defined

In [50]:
input_tensor.shape

torch.Size([119, 1, 30])

In [19]:
input_tensor.flatten() == ab_tensor.flatten()

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True],
       device='cuda:0')

In [131]:
((input_tensor.grad * 1e0) * ab_weights_tensor.transpose(0,1).unsqueeze(2).expand(input_tensor.shape).cuda())[0,100]

tensor([-0.0013,  0.1574, -0.0698,  ..., -0.1786,  0.0265,  0.0289],
       device='cuda:0')

In [11]:
b = torch.nn.Sequential(
    bert.embeddings.position_embeddings,
    bert.embeddings.token_type_embeddings,
    bert.embeddings.LayerNorm,
    bert.embeddings.dropout,
    bert.encoder,
    bert.pooler
)

In [10]:
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30, 1024, padding_idx=0)
    (position_embeddings): Embedding(40000, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-29): 30 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.0, inpla

In [260]:
ab_tensor.flatten() == input_tensor.flatten()

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True],
       device='cuda:0')

In [73]:
t

tensor(0., device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)

In [None]:
self.slen = ab_tensor.shape[0]
#         self.s = 1
#         self.ab_tensor = ab_tensor
#         self.ab_tensor_batch = self.ab_tensor.expand([self.slen,self.s])
#         self.ab_weight = ab_weight.detach()
#         self.ab_length = ab_length.detach()
#         self.ab_length_batch = self.ab_length.expand([self.s])

#         self.src_enc = src_enc.expand([self.s, src_enc.shape[1], 1024]).detach()
#         self.src_len = src_len.expand([self.s]).detach()
#         self.bert_embed = bert_embed.expand([self.s, src_enc.shape[1], 1024]).detach()
#         self.alen = torch.arange(ab_length.max(), dtype=torch.long, device=ab_length.device)
#         self.pred_mask = (self.alen[:, None] < ab_length[None] - 1)[:-1]   # do not predict anything given the last target word
#         self.pred_mask = self.pred_mask.expand([self.slen-1,self.s])
#         self.y = ab_tensor[1:].masked_select(self.pred_mask)
#         self.weight = F.one_hot(ab_tensor, num_classes=30).type(torch.float)
#         self.weight = self.weight.expand([self.slen,self.s,30])
#         self.smoothed_weight = self.normal_label_smoothing(self.weight, self.ab_weight)
#         # self.weight = self.weight + (torch.normal(0, .2, self.weight.shape).cuda()* self.ab_weight.unsqueeze(2).expand(self.weight.shape))
#         self.weight = self.weight.detach().requires_grad_()
#         self.last_weight = self.weight.detach().clone()
#         self.smoothed_weight = self.smoothed_weight.detach().requires_grad_()

#         self.softmax = torch.nn.Softmax(dim=2)
#         self.proj = decoder.embeddings.weight
#         self.first = True
#         self.lr = 1e-2
#         self.loss = 0

    def sum_one(self, tensor):
        return tensor / tensor.sum(2).unsqueeze(2).expand(tensor.shape)
    
    def forward(self, decoder):
        weight = self.smoothed_weight
        
        # weight[(optimizer.ab_weight==1)][:,5:25] = self.linear.weight
        # weight = self.weight
        # loss_sum = -torch.abs((self.weight.sum(dim=2) - 1)).mean(1).mean()
        # loss_reg = torch.norm(weight * self.ab_weight.unsqueeze(2).cuda(), p=2, dim=2).mean()

        # loss_max = (self.weight.max(dim=2)[0] - 1).mean(1).mean()
        loss_reg = -(self.smoothed_weight ** 2).sum(2).mean()
        embed = weight.matmul(decoder.embeddings.weight).transpose(0,1)
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            dec = decoder('fwd', x=self.ab_tensor_batch, lengths=self.ab_length_batch,
                          langs=self.ab_tensor_batch.clone().fill_(params.tgt_id), causal=True,
                          src_enc=self.src_enc, src_len=self.src_len, bert_embed=self.bert_embed,
                          embed_x=embed
                        )
            scores, loss = decoder('predict', tensor=dec, pred_mask=self.pred_mask, y=self.y, get_scores=True)
            scores = F.log_softmax(scores, dim=-1)
            loss = scores[torch.arange(118), self.y].sum()
            slen = scores.shape[0]
            # loss = weight[1:].view(-1,30).matmul(scores.transpose(1,0)).view((slen,slen)).mul(torch.eye(slen).cuda())
            # loss = loss[torch.arange((self.slen-1)*self.s), torch.arange((self.slen-1)*self.s)].view([self.s,-1]).mean(1).mean()

        # print(scores[torch.arange(scores.shape[0]), optimizer.smoothed_weight[1:].argmax(2).flatten()].sum())    
        # return loss
        # loss = -F.cross_entropy(scores[-1].unsqueeze(0), self.weight[-1])
        print(loss)
        # one_hot_target = torch.zeros_like(self.weight).scatter_(1, target_index.unsqueeze(1), 1)
        total_loss = loss_reg + loss.mean()
        print(loss_reg)

        if self.first:
            self.loss = loss
            self.first = False
            
        total_loss.backward()
        flag = self.update()        
        self.new_ab = self.smoothed_weight.argmax(2)
        return flag,loss.mean()
        
    
    def update(self):
        # print(self.weight.grad[100])
        new_weight = self.smoothed_weight + \
            ((self.smoothed_weight.grad * self.lr) * self.ab_weight.unsqueeze(2).expand(self.weight.shape))
        changed = (self.ab_tensor == new_weight.argmax(-1)).all()

        self.smoothed_weight = self.sum_one(new_weight).detach().requires_grad_()
        return changed
    
    def one_hot(self):
        t = self.smoothed_weight.clone()
        t[2:-2, 0, :5] = 0
        new_weight = F.one_hot(t.argmax(dim=2), num_classes=30).type(torch.float).detach().requires_grad_()        
        self.smoothed_weight = self.normal_label_smoothing(new_weight, self.ab_weight).detach().requires_grad_()
        self.y = self.new_ab[1:].masked_select(self.pred_mask)


# ab_test = ab_tensor.clone()
# ab_test[100] = 5
# optimizer = Optimize(decoder, ab_test, ab_weights_tensor, ab_length_tensor, enc1, ag_length_tensor, bert_embed)
# opt = torch.optim.Adam(optimizer.parameters())


In [None]:
    
    def forward_reverse(self, ab_onehot):
        enc, bert = self.encode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor, ab_onehot)
        loss = self.decode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
        print(loss.item())
        
        kl_loss = F.kl_div((ab_onehot+1e-3).log(), (self.one_hot(self.ab_tensor)+1e-5).type(torch.float), reduction='batchmean')

        # kl_loss = F.kl_div(ab_onehot, self.one_hot(self.ab_tensor), reduction='batchmean')
        print(kl_loss)
        l1_norms = torch.norm(ab_onehot, p=1, dim=2)
        reg_term = torch.abs(l1_norms - 1)
        # print(reg_term[90:110])
        loss = loss + 0.005*kl_loss #+ 0.5*reg_term.sum()
        loss.backward()


        changed, new_ab_onehot = self.update(ab_onehot)
        print(changed)
        with torch.no_grad():
            print('new onehot ab-ag')
            enc, bert = self.encode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor, new_ab_onehot)
            loss = self.decode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
            print(loss.item())

            print('new ag-ab')
            temp = new_ab_onehot.clone()
            temp[(self.ab_weights_tensor == 1).flatten(),0,:5] = 0

            enc, bert = self.encode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor)
            loss = self.decode(temp.argmax(-1), self.ab_length_tensor, self.ab_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
            print(loss.item())

            print('original ab-ag')
            enc, bert = self.encode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor)
            loss = self.decode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)
            print(loss.item())


            print('original ag-ab')
            enc, bert = self.encode(self.ag_tensor, self.ag_length_tensor, self.ag_langs_tensor)
            loss = self.decode(self.ab_tensor, self.ab_length_tensor, self.ab_langs_tensor, enc, torch.tensor([enc.shape[1]]).cuda(), bert)

            print(loss.item())
            
        new_ab_onehot[(self.ab_weights_tensor == 1).flatten()] = (new_ab_onehot/new_ab_onehot.sum(-1).unsqueeze(2))[(self.ab_weights_tensor == 1).flatten()]
        new_ab_onehot[new_ab_onehot < 0] = 0
        return new_ab_onehot
