In [3]:
import sys, torch
import torch.nn as nn
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import time
import matplotlib
matplotlib.use('Tkagg')    ## to avoid focus stealing

In [38]:
np.random.seed(0)

In [39]:
data = pd.read_csv('AVPdb_data.csv', skiprows = 1, usecols = range(3), header=None, names=['ID','seq','len'])

In [40]:
data
#seq = np.asarray(data['seq'])
#print(seq[:5])

Unnamed: 0,ID,seq,len
0,AVP0001,PYVGSGLYRR,10
1,AVP0002,SMIENLEYM,9
2,AVP0003,ECRSTSYAGAVVNDL,15
3,AVP0004,STSYAGAVVNDL,12
4,AVP0005,YAGAVVNDL,9
...,...,...,...
2054,AVP2058,LFRLIKSLIKRLVSAFK,17
2055,AVP2059,SLIGGLVSAFK,11
2056,AVP2060,VSAFK,5
2057,AVP2061,KHMHWHPPALNT,12


In [41]:
#Dictionary of 20 canonical amino acids

CHAR_TO_IND = {'_': 0,
 'R': 1,
 'F': 2,
 'L': 3,
 'D': 4,
 'S': 5,
 'T': 6,
 'E': 7,
 'I': 8,
 'N': 9,
 'C': 10,
 'W': 11,
 'Y': 12,
 'A': 13,
 'V': 14,
 'P': 15,
 'G': 16,
 'Q': 17,
 'H': 18,
 'M': 19,
 'K': 20}

IND_TO_CHAR = {
    CHAR_TO_IND[c]: c
    for c in CHAR_TO_IND
}

In [58]:
MAX_PEPTIDE_LENGTH = 5 #Problem Statement requires < 2000 kDa, Avg. amino acid = 110 kDa
NUM_AMINO_ACIDS = len(CHAR_TO_IND)

In [59]:
#need to think about how to deal with smaller sequences
def peptide_to_vector(peptide):
    """Takes an input which is a string of amino acids ie 'AAYS' and returns an array of one-hot vectors of shape 
    (MAX_PEPTIDE_LENGTH,NUM_AMINO_ACIDS)"""
    default = np.zeros([MAX_SEQUENCE_LENGTH, len(CHAR_TO_IND)])
    for i, character in enumerate(peptide[:MAX_PEPTIDE_LENGTH]):
        default[i][CHAR_TO_IND[character]] = 1
    return default

#think about how to deal with non one hot vectors
def vector_to_peptide(one_hot):
    """Takes a one hot vector (MAX_PEPTIDE_LENGTH,NUM_AMINO_ACIDS) and returns the peptide it represents
    Note that argmax on equal values defaults to the smallest index"""
    if one_hot.ndim == 1:
        one_hot = one_hot.reshape((-1,NUM_AMINO_ACIDS))
    
    return ''.join([IND_TO_CHAR[one_hot[i].argmax()]for i in range(len(one_hot))])
        

In [60]:
seq = ['AAAAA','AAAII','AAAHH','AAAFF','AAARR','AAAGG','AAAYY','AAAKK','AAAQQ','AAAPP','AAAVV' ]

In [74]:
def data_function(examples,batch_size, iteration):
    input_array = []
    for j in range(iteration * batch_size, (iteration+1) * batch_size):
        embedding = peptide_to_vector(examples[j])
        embedding = embedding.reshape(-1)
        input_array.append(embedding)
    return input_array

#fix random blanks appearing in the middle
def noise_function(batch_size):
    
    input_array = []
    for j in range(batch_size):
        a = np.zeros([MAX_PEPTIDE_LENGTH,NUM_AMINO_ACIDS ])
        for i in range(MAX_PEPTIDE_LENGTH):
            
            x = np.random.randint(NUM_AMINO_ACIDS) 
            a[i][x] = np.random.normal(0.,1.)
        a = a.reshape(-1)
        input_array.append(a)
    return input_array

In [75]:
data_function(seq,2,0)[0].shape

(105,)

In [84]:
list(map(vector_to_peptide,noise_function(6)))

['__Q__', '___N_', 'H__TY', 'AIG__', '_DRYM', 'TNL__']

In [94]:
class Generator(nn.Module):
    def __init__(self, input_length,output_length):
        """A generator for mapping a random peptide to an antiviral peptide
        Args:
            input_length (int array): max_length * number_of_characters 
                                      ("noise vector")
            layers (List[int]): A list of layer widths including output width
            output_activation: torch activation function or None
        """
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(input_length, 1800)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(1800, 1440)
        self.linear3 = nn.Linear(1440, output_length)
        self.linear4 = nn.Linear(1080, 720)
        self.linear5 = nn.Linear(720, 360)
        self.output_activation = nn.Sigmoid()

    def forward(self, input_tensor):
        """Forward pass; map latent vectors to samples."""
        intermediate = self.linear1(input_tensor)
        intermediate = self.leaky_relu(intermediate)
        intermediate = self.linear2(intermediate)
        intermediate = self.relu(intermediate)
        intermediate = self.linear3(intermediate)
        """intermediate = self.leaky_relu(intermediate)
        intermediate = self.linear4(intermediate)
        intermediate = self.leaky_relu(intermediate)
        intermediate = self.linear5(intermediate)"""
        if self.output_activation is not None:
            intermediate = self.output_activation(intermediate)
        return intermediate

In [95]:
class Discriminator(nn.Module):
    def __init__(self, input_dim, layers):
        """A discriminator for discerning real from generated samples.
        params:
            input_dim (int): width of the input
            layers (List[int]): A list of layer widths including output width
        Output activation is Sigmoid.
        """
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self._init_layers(layers)

    def _init_layers(self, layers):
        """Initialize the layers and store as self.module_list."""
        self.module_list = nn.ModuleList()
        last_layer = self.input_dim
        for index, width in enumerate(layers):
            self.module_list.append(nn.Linear(last_layer, width))
            last_layer = width
            if index + 1 != len(layers):
                self.module_list.append(nn.LeakyReLU())
            else:
                self.module_list.append(nn.Sigmoid())

    def forward(self, input_tensor):
        """Forward pass; map samples to confidence they are real [0, 1]."""
        intermediate = input_tensor
        for layer in self.module_list:
            intermediate = layer(intermediate)
        return intermediate

In [96]:
def simpleGAN(batch_size: int = 25, epochs: int = 5, max_data: int = len(seq), print_every: int = 10):

    #Array to monitor losses
    loss_g = []
    loss_d = []
    input_length = MAX_PEPTIDE_LENGTH*NUM_AMINO_ACIDS
    output_length = MAX_PEPTIDE_LENGTH*NUM_AMINO_ACIDS

    # Models
    generator = Generator(input_length,output_length)
    discriminator = Discriminator(input_length, [64, 32, 1])  

    # Optimizers
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.1)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.1)

    # loss
    loss = nn.BCELoss()

    for i in range(epochs):

        for j in range(int(max_data/batch_size)):
            # zero the gradients on each iteration
            generator_optimizer.zero_grad()


            # Create noisy input for generator
            # Need float type instead of int
            noise = noise_function(batch_size)
            noise_data = torch.tensor(noise).float()
            generated_data = generator(noise_data)
            #print("Generated data in loop")
            #print(generated_data)

            # Generate examples of even real data
            true_data = data_function(seq,batch_size, j)
            if i % print_every ==0:
                print("True data used: ", list(map(vector_to_peptide,true_data)) )
            true_labels = torch.tensor(np.ones(batch_size)).float()
            true_data = torch.tensor(true_data).float()

            # Train the generator
            # We invert the labels here and don't train the discriminator because we want the generator
            # to make things the discriminator classifies as true.
            generator_discriminator_out = discriminator(generated_data)
            generator_loss = loss(generator_discriminator_out, true_labels)
            generator_loss.backward()
            generator_optimizer.step()

            # Train the discriminator on the true/generated data
            discriminator_optimizer.zero_grad()
            true_discriminator_out = discriminator(true_data)
            true_discriminator_loss = loss(true_discriminator_out, true_labels)

            # add .detach() here think about this
            generator_discriminator_out = discriminator(generated_data.detach())
            generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros(batch_size))
            discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2   
            discriminator_loss.backward()
            discriminator_optimizer.step()

        generated_data = generated_data.detach().numpy()
        if i % print_every == 0:
            print("Epoch: ",i,)
            print("generated data:")
            print(generated_data)
            print("peptide of noisedata[0]")
            print(vector_to_peptide(noise_data.numpy()[0]))
            print("peptide of generated data[0]")
            print(vector_to_peptide(generated_data[0]))
        """This threshold only makes sense if there's a sigmoid activation for the Generator"""
        #generated_data[generated_data > 0.5] = 1
        #generated_data[generated_data <= 0.5] = 0           
        sequence = np.reshape(generated_data[0], [MAX_PEPTIDE_LENGTH,NUM_AMINO_ACIDS])

        if i % print_every == 0:
            print("sequence:")
            print(sequence)
            print("peptide of sequence")
            print(vector_to_peptide(sequence))
            print("Generator Loss :",generator_loss.item())
            print("Discriminator Loss: ", discriminator_loss.item())

In [97]:
simpleGAN(batch_size=10, epochs=100, print_every = 20)

True data used:  ['AAAAA', 'AAAII', 'AAAHH', 'AAAFF', 'AAARR', 'AAAGG', 'AAAYY', 'AAAKK', 'AAAQQ', 'AAAPP']
Epoch:  0
generated data:
[[0.50014126 0.5030785  0.49608636 ... 0.50473934 0.49990073 0.50308037]
 [0.49648586 0.49489465 0.49600834 ... 0.49837855 0.4978191  0.5011022 ]
 [0.5016147  0.4979055  0.4948813  ... 0.497995   0.5013907  0.5027408 ]
 ...
 [0.5000933  0.5041413  0.49377877 ... 0.49386957 0.49332958 0.49800912]
 [0.48935634 0.4971755  0.5030058  ... 0.49008027 0.5044937  0.5099868 ]
 [0.49490184 0.4989096  0.50248766 ... 0.4921616  0.50085753 0.5068339 ]]
peptide of noisedata[0]
K_IAA
peptide of generated data[0]
HHFGT
sequence:
[[0.50014126 0.5030785  0.49608636 0.50269157 0.5046573  0.50238925
  0.5053858  0.49650666 0.5052731  0.4986546  0.5006881  0.5075788
  0.49783307 0.5050796  0.4949454  0.50169545 0.499763   0.49433225
  0.5129779  0.5053514  0.5068786 ]
 [0.49709252 0.49281713 0.48931667 0.5022155  0.49631658 0.49923655
  0.4953828  0.49465007 0.50665945 0.497

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


True data used:  ['AAAAA', 'AAAII', 'AAAHH', 'AAAFF', 'AAARR', 'AAAGG', 'AAAYY', 'AAAKK', 'AAAQQ', 'AAAPP']
Epoch:  20
generated data:
[[0. 0. 0. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]
 ...
 [0. 0. 0. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]]
peptide of noisedata[0]
_WR_L
peptide of generated data[0]
D_RR_
sequence:
[[0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 1. 0. 1. 1. 0. 1. 1. 0.]
 [1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 1.]
 [0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1. 0. 1. 1. 1. 1. 0. 0.]
 [0. 1. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]
 [1. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 1.]]
peptide of sequence
D_RR_
Generator Loss : 27.631023406982422
Discriminator Loss:  0.0
True data used:  ['AAAAA', 'AAAII', 'AAAHH', 'AAAFF', 'AAARR', 'AAAGG', 'AAAYY', 'AAAKK', 'AAAQQ', 'AAAPP']
Epoch:  40
generated data:
[[0. 0. 0. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]
 ...
 

In [98]:
epochs = 1000
max_data= len(seq)
batch_size = 2
print_every = 100
loss_g = []
loss_d = []
input_length = MAX_PEPTIDE_LENGTH*NUM_AMINO_ACIDS
output_length = MAX_PEPTIDE_LENGTH*NUM_AMINO_ACIDS

# Models
generator = Generator(input_length,output_length)
discriminator = Discriminator(input_length, [64, 32, 1])  

# Optimizers
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.1)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.1)

# loss
loss = nn.BCELoss()



In [99]:
for i in range(epochs):

    for j in range(int(max_data/batch_size)):
        # zero the gradients on each iteration
        generator_optimizer.zero_grad()


        # Create noisy input for generator
        # Need float type instead of int
        noise = noise_function(batch_size)
        noise_data = torch.tensor(noise).float()
        generated_data = generator(noise_data)
        #print("Generated data in loop")
        #print(generated_data)

        # Generate examples of even real data
        true_data = data_function(seq,batch_size, j)
        """if i % print_every ==0:
            print("True data used: ", list(map(vector_to_peptide,true_data)) )"""
        true_labels = torch.tensor(np.ones(batch_size)).float()
        true_data = torch.tensor(true_data).float()

        # Train the generator
        # We invert the labels here and don't train the discriminator because we want the generator
        # to make things the discriminator classifies as true.
        generator_discriminator_out = discriminator(generated_data)
        generator_loss = loss(generator_discriminator_out, true_labels)
        generator_loss.backward()
        generator_optimizer.step()

        # Train the discriminator on the true/generated data
        discriminator_optimizer.zero_grad()
        true_discriminator_out = discriminator(true_data)
        true_discriminator_loss = loss(true_discriminator_out, true_labels)

        # add .detach() here think about this
        generator_discriminator_out = discriminator(generated_data.detach())
        generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros(batch_size))
        discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2   
        discriminator_loss.backward()
        discriminator_optimizer.step()

    generated_data = generated_data.detach().numpy()
    """if i % print_every == 0:
        print("Epoch: ",i,)
        print("generated data:")
        print(generated_data)
        print("peptide of noisedata[0]")
        print(vector_to_peptide(noise_data.numpy()[0]))
        print("peptide of generated data[0]")
        print(vector_to_peptide(generated_data[0]))"""
    
    """This threshold only makes sense if there's a sigmoid activation for the Generator"""
    #generated_data[generated_data > 0.5] = 1
    #generated_data[generated_data <= 0.5] = 0           
    sequence = np.reshape(generated_data[0], [MAX_PEPTIDE_LENGTH,NUM_AMINO_ACIDS])

    if i % print_every == 0:
        print("sequence:")
        print(sequence)
        print("peptide of sequence")
        print(vector_to_peptide(sequence))
        print("Generator Loss :",generator_loss.item())
        print("Discriminator Loss: ", discriminator_loss.item())

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


sequence:
[[0. 0. 0. 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0.]
 [0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0.]
 [1. 1. 0. 0. 1. 1. 1. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1.]
 [1. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0. 1. 1. 0. 1. 0. 1. 0.]]
peptide of sequence
LRE__
Generator Loss : 38.396087646484375
Discriminator Loss:  0.06820950657129288
sequence:
[[0. 0. 0. 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0.]
 [0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0.]
 [1. 1. 0. 0. 1. 1. 1. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1.]
 [1. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0. 1. 1. 0. 1. 0. 1. 0.]]
peptide of sequence
LRE__
Generator Loss : 27.63102149963379
Discriminator Loss:  0.0
sequence:
[[0. 0. 0. 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0.]
 [0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 

In [71]:
true_data = data_function(seq,2,0)
true_data = torch.tensor(true_data).float()
true_data

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [209]:
discriminator(true_data)

tensor([[1.],
        [1.]], grad_fn=<SigmoidBackward>)

In [100]:
fake_data = noise_function(5)
fake_data = torch.tensor(fake_data).float()
fake_data

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.3637,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2313,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1523,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.1494,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.3479,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  

In [None]:



#####################################
# unset BREAK_EARLY if we must not stop training earlier than max_epochs

BREAK_EARLY = True
max_epochs = 100
BLANK_CHAR = '~'


#####################################

##################################
# argv
# either pass 0 args or 1 args
if (len(sys.argv)) == 1:
    Arg_dataset_name = "none"
    Arg_num_layers = 1
    Arg_seqmaxlen = 22
else:
    _, Arg_dataset_name, Arg_num_layers, Arg_seqmaxlen = sys.argv
    Arg_num_layers = int(Arg_num_layers)
    Arg_seqmaxlen = int(Arg_seqmaxlen)

Imagesuffix = Arg_dataset_name + ".L" + str(Arg_num_layers) + ".s" + str(Arg_seqmaxlen)
##################################

if Arg_dataset_name == 'coffee':
    learning_rate = 0.1
elif Arg_dataset_name == 'emails':
    learning_rate = 0.01
elif Arg_dataset_name == 'tls':
    learning_rate = 0.02
##################################


Train_end_reason = "Max_epoch done"

#inp_alphabet = ".ab@_yz"
#out_alphabet = ".FT"

def get_alphabet(d_set):
    letters = set()
    for word in d_set:
        letters.update(set(word))

    return list([ BLANK_CHAR ]) + sorted(list(letters))


def simple_encode_strlist(X, max, alphabet):
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    #example: {'.': 0, 'a': 1, 'b': 2}
    X = [ [char_to_int[char] for char in x1.strip()] for x1 in X]
    X1 = [ x + [0] * (max - len(x)) for x in X]      # pad to max
    #example: simple_encode_inp(["aa","bb"],3) gives [[1, 1, 0], [2, 2, 0]]
    return X1

def simple_decode_idxlist1(enc_data, alphabet):
    idx2char = alphabet
    out = [idx2char[int(x)] for x in enc_data]
    out = ''.join(out)
    return out

def simple_decode_idxlist(d_arr, alphabet):
    return [simple_decode_idxlist1(d, alphabet) for d in d_arr]

def load_data_file(filename):
    data = pandas.read_csv(filename, header=None)
    print ("=== loaded data shape: ", data.shape)
    print ("=== data ", data)

    X_pandas = data[0]
    X_pandas = [ x.strip() for x in X_pandas]
    X_max = max([ len(x) for x in X_pandas])

    Y_pandas = data[1]
    Y_pandas = [ y.strip() for y in Y_pandas]
    Y_max = max([ len(y) for y in Y_pandas])

    if X_max != Y_max:
        print("max len mismatch in X and Y")
        sys.exit(1)

    seq_max = X_max
    return X_pandas, Y_pandas, seq_max

#################################


def ohe_singleletter(val, max):
    letter = [0 for _ in range(max)]
    letter[val] = 1
    return letter


def simple_to_onehot(D, alphabet):
    itemlist = list()
    for d in D:
        item = list()
        for e in d:
            l = ohe_singleletter(e, len(alphabet))
            item.append(l)
            #print(e, l)
        #print(d, "------",  item)
        itemlist.append(item)
    return itemlist

def onehot_decode_to_simple1(d):
    arr = np.array(d)
    idx = arr.argmax(1)
    return idx

def onehot_decode_to_simple(D):
    arr = np.array(D)
    idx = arr.argmax(2)
    return idx


def decode_inp(input):
    input = onehot_decode_to_simple1(input)
    input = simple_decode_idxlist1(input, Inp_alphabet)
    return input

def decode_out(output):
    output = simple_decode_idxlist1(output, Out_alphabet)
    return output

#################################

X_data, Y_data, train_Seq_max = load_data_file('rnn3-train-data.txt')
Num_io_data = len(X_data)

# set Seq_max to  max in both test and train
_, _, test_Seq_max = load_data_file('rnn3-test-data.txt')
Seq_max = max([train_Seq_max, test_Seq_max])

# if Seq_max <= Arg_seqmaxlen:
#     print("=== updating Seq_max from", Seq_max, " to ", Arg_seqmaxlen)
#     Seq_max = Arg_seqmaxlen
# else:
#     print("=== seq max len mismatch", Seq_max, Arg_seqmaxlen)
#     sys.exit(1)

if Arg_dataset_name == 'tls':
    Inp_alphabet =  [BLANK_CHAR, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
    Out_alphabet =  [ 'E', 'S', 'X', 'H', 'U', 'F', 'W', 'I', 'A', 'C' ]
    Out_alphabet =  list([ BLANK_CHAR ]) + sorted(Out_alphabet)

else:
    Inp_alphabet = get_alphabet(X_data)
    Out_alphabet = get_alphabet(Y_data)
#print(''.join(Inp_alphabet))
#print(''.join(Out_alphabet))
print("=== inp_alphabet:", Inp_alphabet)
print("=== out_alphabet:", Out_alphabet)

X_data = simple_encode_strlist(X_data, Seq_max, Inp_alphabet)
Y_data = simple_encode_strlist(Y_data, Seq_max, Out_alphabet)

#print(X_data)

Num_samples = len(X_data)

#print(simple_decode_idxlist1(X_data[0], Inp_alphabet))

#print(simple_decode_idxlist(X_data, Inp_alphabet))
#print(simple_decode_idxlist(Y_data, Out_alphabet))

X_ohe = simple_to_onehot(X_data, Inp_alphabet)
#print(X_ohe)

#print(onehot_decode_to_simple(X_ohe))
#print("------------")

#################################
class Model_RNN(nn.Module):
    def __init__(self):
        super(Model_RNN, self).__init__()
        self.rnn = nn.RNN(
                input_size = input_size,
                hidden_size = hidden_size,
                num_layers = num_layers,
                batch_first = True)
        self.Rnn_type = 'RNN'

    def forward(self, x, hidden):
        # Reshape input in (bs, seqlen, inpsz)
        x = x.view(batch_size, seq_len, input_size)

        # Propagete inp thru RNN
        #   Input: (batchsz, seq_len, inpsz)
        out, hidden = self.rnn(x, hidden)

        # Output: (batchsz, seq_len, hiddensz)
        out = out.view(-1, num_classes)
        return hidden, out

    def init_hidden(self):
        # Init hidden and cell states
        # (num_layers * num_dir, batch, hidden_sz)
        return Variable(torch.zeros(num_layers, batch_size, hidden_size))

class Model_LSTM(nn.Module):
    def __init__(self):
        super(Model_LSTM, self).__init__()
        self.rnn = nn.LSTM(
                input_size = input_size,
                hidden_size = hidden_size,
                num_layers = num_layers,
                batch_first = True)
        self.Rnn_type = 'LSTM'

    def forward(self, x, hidden):
        # Reshape input in (bs, seqlen, inpsz)
        x = x.view(batch_size, seq_len, input_size)

        # Propagete inp thru RNN
        #   Input: (batchsz, seq_len, inpsz)
        out, hidden = self.rnn(x, hidden)

        # Output: (batchsz, seq_len, hiddensz)
        out = out.view(-1, num_classes)
        return hidden, out

    def init_hidden(self):
        # Init hidden and cell states
        # (num_layers * num_dir, batch, hidden_sz)
        # LSTM hidden state is a tuple (h_0, c_0)
        h_0 = Variable(torch.zeros(num_layers, batch_size, hidden_size))
        c_0 = Variable(torch.zeros(num_layers, batch_size, hidden_size))
        hidden = (h_0, c_0)
        return hidden

#################################

Inp_size = len(Inp_alphabet)
Out_size = len(Out_alphabet)
torch.manual_seed(2.7321)

#################################
num_classes = Out_size          #  XXX why need a separate num_classes,  when hidden_size would do?
input_size = Inp_size  # this parameter is for the view function to know how large
#the one hot vector is supposed to be

hidden_size = Out_size #
batch_size = 1   # id dont understand this
seq_len = Seq_max
num_layers = Arg_num_layers  # num-layers of rnn
#learning_rate = 0.01
momentum = 0.1



vizdelay = 30
################################

# function to reduce learning rate based on accuracy.
# it returns a new optimizer
def get_new_optimizer(learning_rate, accuracy):
    new_learning_rate = learning_rate

    if accuracy >= 50:
        new_learning_rate = learning_rate * 0.9
    if accuracy >= 60:
        new_learning_rate = learning_rate * 0.8
    if accuracy >= 70:
        new_learning_rate = learning_rate * 0.7
    if accuracy >= 80:
        new_learning_rate = learning_rate * 0.5
    if accuracy >= 90:
        new_learning_rate = learning_rate * 0.4
    if accuracy >= 95:
        new_learning_rate = learning_rate * 0.2
    if accuracy >= 98:
        new_learning_rate = learning_rate * 0.1

    optimizer = torch.optim.Adam(model.parameters(), lr=new_learning_rate)
    return optimizer

############## Model Training code ##########################
def train_phase():
    global Optimizer
    MIN_LOSS = float('inf')
    MIN_LOSS_epoch_counter = 0
    MAX_ACC_epoch_counter = 0

    final_train_epoch=0
    final_train_accuracy=0

    for epoch in range(max_epochs):
        Optimizer.zero_grad()
        loss = 0
        errcount = 0
        # if epoch % 30 ==0:
        #     evaluate_model(inputs,labels)
        for input, label in zip(Inputs, Labels):
            # input = input.unsqueeze(0)
            hidden = model.init_hidden() #we reset the RNN to its initial state
            hidden, output = model(input, hidden) #run the model

            val, idx = output.max(1)
            expected = decode_out(label)
            trained = decode_out(idx)

            if trained != expected:
                errcount += 1

            if epoch % vizdelay == 0:
                if (trained != expected):
                    print("check : ", decode_inp(input), " -> expected: ", expected, "        predicted: ", trained, "   *****")
                else:
                    print("check : ", decode_inp(input), " -> expected: ", expected, "        predicted: ", trained)
            loss += Criterion(output, label) #add the current sample error to loss

        loss.backward()
        Optimizer.step()

        ########### Within loop plotting and logging ########
        if loss.data < MIN_LOSS:
            MIN_LOSS = loss.data
            MIN_LOSS_epoch_counter = epoch   # reset

        accuracy = 100.0 * (Num_io_data - errcount) / Num_io_data
        print("Epoch: %d, loss: %1.3f         errcount: %d  accuracy: %1.1f%%\n" % (epoch+1, loss.data, errcount, accuracy))
        print("-------------------------------------------")

        plot_data.append([loss.data.item(), accuracy])
        t = [ x for x in range(0, epoch)]
        p = np.array(plot_data)
        plt_ax1.plot(p[:, 0], color='red')
        plt_ax2.plot(p[:, 1], color='blue')
        plt.pause(0.001)
        #plt.show()

        final_train_epoch = epoch
        final_train_accuracy = accuracy
        if BREAK_EARLY == True:
            global Train_end_reason
            if (epoch - MIN_LOSS_epoch_counter) >= 20:
                Train_end_reason = "=== no new low of training_loss seen for last 20 epochs; stop training"
                print(Train_end_reason)
                break

            if accuracy >= 99.9:
                MAX_ACC_epoch_counter += 1
                # stop, if accuracy stays at ~100 for 10 epochs
                if MAX_ACC_epoch_counter >= 10:
                    Train_end_reason = "=== maximal accuracy seen for last 10 epochs; stop training"
                    print(Train_end_reason)
                    time.sleep(4)
                    break

        # update optimizer with changed lr - depending on accuracy
        Optimizer = get_new_optimizer(learning_rate, accuracy)
    return final_train_accuracy, final_train_epoch


####################### Model evaluation code ################
def evaluate_model_1(xdata, ydata):
    x_one_hot = simple_to_onehot(xdata, Inp_alphabet)
    inputs = Variable(torch.Tensor(x_one_hot))
    labels = Variable(torch.LongTensor(ydata))

    err_count = 0
    for input, label in zip(inputs, labels):
        expected = decode_out(label)
        hidden = model.init_hidden()
        hidden, output = model(input, hidden)
        val, idx = output.max(1)
        predicted = decode_out(idx)
        if (expected != predicted):
            print("check : ", decode_inp(input), " -> expected: ", expected, "        predicted: ", predicted, "   *****")
            err_count += 1
        else:
            print("check : ", decode_inp(input), " -> expected: ", expected, "        predicted: ", predicted)
    return (err_count)

def evaluate_model(xdata, ydata, seqmax):
    xdata = simple_encode_strlist(xdata, seqmax, Inp_alphabet)
    ydata = simple_encode_strlist(ydata, seqmax, Out_alphabet)
    return evaluate_model_1(xdata, ydata)

################################################################################
def validation_phase():
    print("============================= validation inputs ===========================")
    validation_inputs, validation_outputs, seqmax = load_data_file('rnn3-train-data.txt')
    num_io_data = len(validation_inputs)
    if seqmax != Seq_max:
        print("seqmax mismatch", seqmax, Seq_max)
        sys.exit(1)
    seqmax = Seq_max

    errcount = evaluate_model(validation_inputs, validation_outputs, seqmax)
    val_accuracy = 100.0 * (num_io_data - errcount) / num_io_data

    if errcount > 0:
        print("VALIDATION FAILED:  errors: ", errcount, "accuracy:", val_accuracy)
    else:
        print("Validation Passed")
    return val_accuracy


################################################################################
def test_phase():
    print("============================= test inputs ===========================")
    test_inputs, test_outputs, seqmax = load_data_file('rnn3-test-data.txt')
    num_io_data = len(test_inputs)
    #if seqmax != Seq_max: #this could happen just due to randomness #todo
    #    print("WARN: seqmax mismatch")     ## possible to have diff seq max from the train data
    #    sys.exit(1)
    seqmax = Seq_max


    errcount = evaluate_model(test_inputs, test_outputs, seqmax)
    test_accuracy = 100.0 * (num_io_data - errcount) / num_io_data

    print("Testset size: ", len(test_inputs),"Test errors: ", errcount, "accuracy:", test_accuracy)
    return test_accuracy

################################################################################

########################################################
#Instantiate RNN model
#model = Model_RNN()
Rnn_type = '###'
model = Model_LSTM()
Rnn_type = model.Rnn_type

####################### Plotting code ##################

plt.ion()
plt_fig, plt_ax1 = plt.subplots()
plt_ax2 = plt_ax1.twinx()
plt_ax1.set_xlabel('epoch')
plt_ax1.set_ylabel('loss', color='red')
plt_ax2.set_ylabel('accuracy', color='blue')

txt = "Dataset: " + Arg_dataset_name + "\nrnn_type: "+ Rnn_type + "\nnum_layers: "+ str(num_layers) + "\nnum_io_samples: "+ str(Num_io_data) + "\nseq_max_len: "+ str(Seq_max)
txt += "\nBase_LR: " + str(learning_rate)

#plt_fig = plt.figure()
plt_fig.text(.5, .2, txt, ha='center', transform=plt_ax1.transAxes)
plt.pause(0.001)

plot_data = []
########################################################



################## Load trained model ###############
from pathlib import Path
modelimagefile = "rnn3model." + Imagesuffix + ".pt"
if Path(modelimagefile).is_file():
    # file exists
    model.load_state_dict(torch.load(modelimagefile))
    model.eval()
    print("=== Model was loaded from " + modelimagefile)

####################################################
Criterion = torch.nn.CrossEntropyLoss()
Optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
####################################################



# X is OH encoded
# Y is simple encoded
# this is the format needed by pytorch
Inputs = Variable(torch.Tensor(X_ohe))
Labels = Variable(torch.LongTensor(Y_data))

print("Input shape: ", Inputs.size())
print("Output shape: ", Labels.size())

final_train_accuracy, final_train_epoch = train_phase()

############ final plotting and logging ################
#plt.plot(plot_data)
#plt.waitforbuttonpress()
plt.savefig('rnn3-ttt.png',  bbox_inches='tight')
plt.show()
########################################################

val_accuracy = validation_phase()
test_accuracy = test_phase()

print("---------------------------------------------------------------------")
print("=== rnn_type:", Rnn_type,
        "num_layers:", num_layers,
        "num_io_samples (train): ", Num_io_data,
        "seq_max_len:", Seq_max)

print("\n Training end due to: ", Train_end_reason)
print("\n=== final_train_epoch:", final_train_epoch)
print("=== final_train_accuracy:  %1.2f%%       \n=== val_accuracy:  %1.2f%%        \n=== test_accuracy:  %1.2f%%\n"
        %  (final_train_accuracy, val_accuracy, test_accuracy))


######################### save trained model #######################
torch.save(model.state_dict(), modelimagefile)
print("=== Model was saved as " + modelimagefile)
####################################################################