In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import fa_support
import evall
import random

In [None]:
class WGAN(tf.keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        discriminator_extra_steps=5,
        batch_size=100
    ):
        """Wrap a discriminator/generator pair for WGAN-GP training.

        Args:
            discriminator: Critic model, stored as ``self.discriminator``.
            generator: Generator model, stored as ``self.generator``.
            discriminator_extra_steps: Number of discriminator updates
                performed in each ``train_step`` call before the single
                generator update.
            batch_size: Number of rows requested from ``fa_support`` for
                each training batch.
        """
        super(WGAN, self).__init__()
        # Bind the models handed in by the caller. Reading module-level
        # globals here would ignore the constructor arguments entirely.
        self.discriminator = discriminator
        self.generator = generator
        self.d_steps = discriminator_extra_steps
        self.batch_size = batch_size
        # NOTE(review): not read by any method visible in this class --
        # confirm it is still needed.
        self.k = 10
        # Start offset of the next training batch; advanced in train_step.
        self.index = 0
        # Start offset of the counter-example batch read in train_step.
        self.c_index = 0
        # Multiplier applied to the gradient penalty in the critic loss.
        self.gp_weight = 10
        # Number of train_step calls so far; drives the periodic evaluation.
        self.eval_steps = 0
        # Best value seen so far for each ranking metric reported by the
        # periodic evaluation in train_step.
        self.max_p10 = .01
        self.max_g10 = .01
        self.max_m10 = .01
        self.max_p20 = .01
        self.max_g20 = .01
        self.max_m20 = .01
        self.val_losses = []
    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn,c_loss_fn, run_eagerly):
        """Attach the optimizers and loss functions used by the training loop.

        Args:
            d_optimizer: Optimizer applied to the discriminator variables.
            g_optimizer: Optimizer applied to the generator variables.
            d_loss_fn: Callable invoked as ``d_loss_fn(real_logits, fake_logits)``.
            g_loss_fn: Callable invoked as ``g_loss_fn(gen_logits)``.
            c_loss_fn: Callable invoked as ``c_loss_fn(counter_logits)``.
            run_eagerly: Value assigned to ``self.run_eagerly`` after the
                base-class ``compile`` has run.
        """
        super().compile()

        # Optimizers, one per sub-model.
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

        # Loss callables.
        self.g_loss_fn = g_loss_fn
        self.d_loss_fn = d_loss_fn
        self.c_loss_fn = c_loss_fn

        self.run_eagerly = run_eagerly



    def gradient_penalty(self, batch_size, real_users, fake_users, brand_id,
                         category, colour, divisioncode, itemcategorycode,
                         itemfamilycode, itemseason, productgroup):
        """Calculate the WGAN-GP gradient penalty for one batch.

        The penalty is evaluated on user vectors interpolated between
        ``real_users`` and ``fake_users`` and is meant to be added to the
        discriminator loss.

        Args:
            batch_size: Number of rows; used as the leading dimension of the
                interpolation coefficients.
            real_users: Real user vectors for the batch.
            fake_users: Generated user vectors for the same batch.
            brand_id, category, colour, divisioncode, itemcategorycode,
            itemfamilycode, itemseason, productgroup: Item attributes that
                are forwarded unchanged to ``discriminator_input``.

        Returns:
            Scalar tensor: mean over the batch of ``(grad_norm - 1) ** 2``,
            plus a constant ``1e-5``.
        """
        # One coefficient per row, drawn uniformly from [0, 1) so that every
        # interpolated point lies on the segment between the real and the
        # fake sample, as the WGAN-GP penalty requires. (A normal draw would
        # place points outside that segment.)
        alpha = tf.random.uniform([batch_size, 1], 0.0, 1.0)
        diff = fake_users - real_users
        interpolated = real_users + alpha * diff
        with tf.GradientTape() as tape:
            # `interpolated` is not a variable, so it must be watched
            # explicitly for the tape to differentiate w.r.t. it.
            tape.watch(interpolated)
            # 1. Discriminator output for the interpolated users.
            interpolated_input = discriminator_input(
                brand_id, category, colour, divisioncode, itemcategorycode,
                itemfamilycode, itemseason, productgroup, interpolated)
            pred = self.discriminator(interpolated_input)

        # 2. Gradients of the output w.r.t. the interpolated users. The tiny
        #    offset is presumably there to keep the norm away from exactly
        #    zero, where the gradient of sqrt is undefined.
        grads = tape.gradient(pred, [interpolated])[0] + 1e-10
        # 3. Per-row L2 norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
        gp = tf.reduce_mean((norm - 1.0) ** 2) + 1e-5
        return gp

    def gradient_penalty_val(self, batch_size, real_users, fake_users, brand_id,
                             category, colour, divisioncode, itemcategorycode,
                             itemfamilycode, itemseason, productgroup):
        """Calculate the WGAN-GP gradient penalty using ``discriminator_input_new``.

        Same computation as ``gradient_penalty`` except that the
        discriminator input is assembled by ``discriminator_input_new``.

        Args:
            batch_size: Number of rows; used as the leading dimension of the
                interpolation coefficients.
            real_users: Real user vectors for the batch.
            fake_users: Generated user vectors for the same batch.
            brand_id, category, colour, divisioncode, itemcategorycode,
            itemfamilycode, itemseason, productgroup: Item attributes that
                are forwarded unchanged to ``discriminator_input_new``.

        Returns:
            Scalar tensor: mean over the batch of ``(grad_norm - 1) ** 2``,
            plus a constant ``1e-5``.
        """
        # One coefficient per row, drawn uniformly from [0, 1) so that every
        # interpolated point lies on the segment between the real and the
        # fake sample, as the WGAN-GP penalty requires. (A normal draw would
        # place points outside that segment.)
        alpha = tf.random.uniform([batch_size, 1], 0.0, 1.0)
        diff = fake_users - real_users
        interpolated = real_users + alpha * diff
        with tf.GradientTape() as tape:
            # `interpolated` is not a variable, so it must be watched
            # explicitly for the tape to differentiate w.r.t. it.
            tape.watch(interpolated)
            # 1. Discriminator output for the interpolated users.
            interpolated_input = discriminator_input_new(
                brand_id, category, colour, divisioncode, itemcategorycode,
                itemfamilycode, itemseason, productgroup, interpolated)
            pred = self.discriminator(interpolated_input)

        # 2. Gradients of the output w.r.t. the interpolated users. The tiny
        #    offset is presumably there to keep the norm away from exactly
        #    zero, where the gradient of sqrt is undefined.
        grads = tape.gradient(pred, [interpolated])[0] + 1e-10
        # 3. Per-row L2 norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
        gp = tf.reduce_mean((norm - 1.0) ** 2) + 1e-5
        return gp

    def train_step(self, real_users):
        """Run one WGAN-GP iteration: ``d_steps`` critic updates, then one generator update.

        Every 760th call additionally evaluates ranking metrics through
        ``self.test`` at k=10 and k=20 and records the best values seen.

        Args:
            real_users: Not used as given -- the name is rebound to the batch
                returned by ``fa_support.get_batchdata``.

        Returns:
            dict with ``"training_loss"`` (the discriminator loss of the last
            critic update); on evaluation steps it also carries the
            ``p10/G10/M10/p20/G20/M20`` metrics and their ``val_`` variants.
        """
        # NOTE(review): this method mutates plain Python attributes
        # (eval_steps, index, max_*), so it presumably has to run with
        # run_eagerly=True -- confirm.
        self.eval_steps += 1
        c_batch_size = 2 * self.batch_size

        for _ in range(self.d_steps):
            with tf.GradientTape() as tape:
                # Training batch. `self.index` only moves after the generator
                # update, so every critic update sees the same batch.
                (country, postcode, loyal, gender, brand_id, category, colour,
                 divisioncode, itemcategorycode, itemfamilycode, itemseason,
                 productgroup, real_users, items) = fa_support.get_batchdata(
                    self.index, self.index + self.batch_size)
                # Counter examples.
                # NOTE(review): self.c_index is never advanced in this class,
                # so the same range is requested on every call -- confirm
                # that this is intended.
                (counter_brand_id, counter_category, counter_colour,
                 counter_divisioncode, counter_itemcategorycode,
                 counter_itemfamilycode, counter_itemseason,
                 counter_productgroup,
                 counter_users) = fa_support.get_counter_batch(
                    self.c_index, self.c_index + c_batch_size)

                item_attrs = (brand_id, category, colour, divisioncode,
                              itemcategorycode, itemfamilycode, itemseason,
                              productgroup)

                # Generate fake users from the item attributes.
                fake_users = self.generator(generator_input(*item_attrs))
                # Critic scores for fake, real and counter-example users.
                fake_logits = self.discriminator(
                    discriminator_input(*item_attrs, fake_users))
                real_logits = self.discriminator(
                    discriminator_input(*item_attrs, real_users))
                counter_logits = self.discriminator(discriminator_input(
                    counter_brand_id, counter_category, counter_colour,
                    counter_divisioncode, counter_itemcategorycode,
                    counter_itemfamilycode, counter_itemseason,
                    counter_productgroup, counter_users))

                d_cost = self.d_loss_fn(real_logits, fake_logits)
                c_loss = self.c_loss_fn(counter_logits)
                gp = self.gradient_penalty(
                    self.batch_size, real_users, fake_users, *item_attrs)
                d_loss = d_cost + c_loss + 1e-5 + gp * self.gp_weight

            # Update the discriminator weights from the critic loss.
            d_gradient = tape.gradient(
                d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables))

        # Train the generator on the attributes of the last critic batch.
        with tf.GradientTape() as tape:
            gen_users = self.generator(generator_input(*item_attrs))
            gen_logits = self.discriminator(
                discriminator_input(*item_attrs, gen_users))
            g_loss = self.g_loss_fn(gen_logits) + 1e-5

        gen_gradient = tape.gradient(
            g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables))

        self.index = self.index + self.batch_size

        if self.eval_steps % 760 != 0:
            return {"training_loss": d_loss}

        # Periodic evaluation. Call through `self` so the method does not
        # depend on a module-level instance happening to be named `wgan`.
        p_at_10, G_at_10, M_at_10, val_p10, val_g10, val_m10 = self.test(10)
        p_at_20, G_at_20, M_at_20, val_p20, val_g20, val_m20 = self.test(20)

        # Track the best value seen so far for each metric.
        self.max_p10 = max(self.max_p10, p_at_10)
        self.max_g10 = max(self.max_g10, G_at_10)
        self.max_m10 = max(self.max_m10, M_at_10)
        self.max_p20 = max(self.max_p20, p_at_20)
        self.max_g20 = max(self.max_g20, G_at_20)
        self.max_m20 = max(self.max_m20, M_at_20)

        return {
            "training_loss": d_loss,
            "p10": p_at_10,
            "G10": G_at_10,
            "M10": M_at_10,
            "p20": p_at_20,
            "G20": G_at_20,
            "M20": M_at_20,
            "val_p10": val_p10,
            "val_G10": val_g10,
            "val_M10": val_m10,
            "val_p20": val_p20,
            "val_G20": val_g20,
            "val_M20": val_m20,
        }
        
    def test(self, k):
        """Compute ranking metrics at cut-off ``k`` on two slices of the test data.

        Rows 0-5000 of ``fa_support.get_testdata`` give the first triple and
        rows 5000-10000 the ``val_`` triple.

        Args:
            k: Cut-off passed to ``fa_support.get_intersection_similar_user``
                and to ``evall.ndcg_at_k``.

        Returns:
            Tuple ``(p_at_k, G_at_k, M_at_k, val_p_at_k, val_G_at_k,
            val_M_at_k)``.
        """
        p_at_k, G_at_k, M_at_k = self._ranking_metrics(0, 5000, k)
        val_p_at_k, val_G_at_k, val_M_at_k = self._ranking_metrics(5000, 10000, k)
        return p_at_k, G_at_k, M_at_k, val_p_at_k, val_G_at_k, val_M_at_k

    def _ranking_metrics(self, start, end, k):
        """Return ``(precision, mean NDCG, MAP)`` at ``k`` for test rows ``start``..``end``."""
        (country, postcode, loyal, gender, brand_id, category, colour,
         divisioncode, itemcategorycode, itemfamilycode, itemseason,
         productgroup, item, user, user_emb) = fa_support.get_testdata(start, end)

        n_rows = country.size
        gen_users = self.generator(generator_input(
            brand_id, category, colour, divisioncode, itemcategorycode,
            itemfamilycode, itemseason, productgroup))
        sim_users = fa_support.get_intersection_similar_user(gen_users, k)

        hits = 0
        ndcg_total = 0.0
        relevance_lists = []
        # Single pass: the relevance list built for NDCG/MAP also yields the
        # hit count needed for precision.
        for item_id, user_list in zip(item, sim_users):
            relevance = [ui_matrix[u, item_id] for u in user_list]
            hits += sum(1 for rel in relevance if rel == 1)
            ndcg_total += evall.ndcg_at_k(relevance, k, method=1)
            relevance_lists.append(relevance)

        precision = round(hits / (n_rows * k), 4)
        mean_ndcg = ndcg_total / n_rows
        mean_ap = evall.mean_average_precision(relevance_lists)
        return precision, mean_ndcg, mean_ap
    
    def test_step(self,data):
        """Compute the discriminator loss on the held-out slice as ``val_loss``.

        The loss has the same form as the critic loss in ``train_step``:
        ``d_loss_fn`` + ``c_loss_fn`` + ``1e-5`` + weighted gradient penalty.
        No weights are updated, so the loss is evaluated once and without a
        gradient tape around it.

        Args:
            data: Not used -- the batch is read from
                ``fa_support.get_testdata(5000, 10000)`` instead.
                NOTE(review): confirm that ignoring ``data`` is intended.

        Returns:
            dict with the single key ``"val_loss"``.
        """
        (counter_brand_id, counter_category, counter_colour,
         counter_divisioncode, counter_itemcategorycode,
         counter_itemfamilycode, counter_itemseason, counter_productgroup,
         counter_users) = fa_support.get_val_counter_batch(0,10000)

        (country, postcode, loyal, gender, brand_id, category, colour,
         divisioncode, itemcategorycode, itemfamilycode, itemseason,
         productgroup, item, user, user_emb) = fa_support.get_testdata(5000,10000)

        item_attrs = (brand_id, category, colour, divisioncode,
                      itemcategorycode, itemfamilycode, itemseason,
                      productgroup)

        # Critic scores for generated, real and counter-example users.
        gen_users = self.generator(generator_input(*item_attrs))
        fake_logits = self.discriminator(
            discriminator_input(*item_attrs, gen_users))
        real_logits = self.discriminator(
            discriminator_input(*item_attrs, user_emb))
        counter_logits = self.discriminator(discriminator_input(
            counter_brand_id, counter_category, counter_colour,
            counter_divisioncode, counter_itemcategorycode,
            counter_itemfamilycode, counter_itemseason,
            counter_productgroup, counter_users))

        c_loss = self.c_loss_fn(counter_logits)
        d_cost = self.d_loss_fn(real_logits, fake_logits)
        # Take the row count from the data instead of assuming 5000 rows, so
        # a shorter slice does not break the interpolation inside
        # gradient_penalty.
        gp = self.gradient_penalty(tf.shape(user_emb)[0], user_emb, gen_users,
                                   *item_attrs)
        d_loss = d_cost + c_loss + 1e-5 + gp*self.gp_weight

        return {"val_loss": d_loss}
    
    
    