<a href="https://colab.research.google.com/github/mahn-bonnie/Generative-AI-Series/blob/main/FNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

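**Text Generation using FNet**

Transformer-based models process sequences well because of a mechanism known as “self-attention”: each token is compared with every other token in the sequence to determine how strongly the two are related. FNet replaces the self-attention sublayer of the encoder with a Fourier transform, which mixes tokens without learned attention weights. This notebook builds an FNet encoder, pairs it with an attention-based decoder, and trains the resulting model for text generation.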
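
The pairwise comparison behind self-attention can be sketched in a few lines of plain Python. This is a minimal illustration that assumes identity query/key/value projections and made-up 2-dimensional token vectors; real Transformer layers learn those projections and operate on tensors.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens):
    """Scaled dot-product self-attention with identity projections.

    tokens: list of equal-length vectors, one per token.
    Returns (weights, outputs); weights[i][j] is how much token i attends to token j.
    """
    dim = len(tokens[0])
    weights = []
    for query in tokens:
        # Compare this token with every token in the sequence (including itself)
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim) for key in tokens]
        weights.append(softmax(scores))
    # Each output is a weighted average of all token vectors
    outputs = [
        [sum(w * value[d] for w, value in zip(row, tokens)) for d in range(dim)]
        for row in weights
    ]
    return weights, outputs

# Three toy token vectors (illustrative values only)
toy_tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
attn_weights, attn_outputs = self_attention(toy_tokens)
for row in attn_weights:
    print([round(w, 3) for w in row])
```

Every token produces one row of weights over the whole sequence, which is why the cost of self-attention grows quadratically with sequence length.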

**Step 1: Libraries and Imports**

In [1]:
# Install the libraries below if they are not available in your environment

!pip install datasets
!pip install transformers torch


In [2]:
#Declare device variable for computation on GPU if available

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)


cuda


**Step 2: Load Data**

In [3]:
from datasets import load_dataset
datasets = load_dataset('wikitext','wikitext-2-raw-v1')



**Step 3: Data Preprocessing**

In [4]:
import re


def preprocess_text(sentence):
	# Lowercase the text
	text = sentence['text'].lower()
	# Replace everything except lowercase letters and basic punctuation (? ! . ,) with a space
	text = re.sub(r'[^a-z?!.,]', ' ', text)
	# Collapse runs of whitespace into a single space
	text = re.sub(r'\s\s+', ' ', text)
	sentence['text'] = text
	return sentence


datasets['train'] = datasets['train'].map(preprocess_text)
datasets['test'] = datasets['test'].map(preprocess_text)
datasets['validation'] = datasets['validation'].map(preprocess_text)

datasets['train'] = datasets['train'].filter(lambda x: len(x['text']) > 20)
datasets['test'] = datasets['test'].filter(lambda x: len(x['text']) > 20)
datasets['validation'] = datasets['validation'].filter(
	lambda x: len(x['text']) > 20)
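
To see what the cleaning step does to a single line, here is a small stand-alone sketch. It applies the same three steps to a plain string; `clean_text` and the sample sentence are illustrative names and data, not part of the notebook.

```python
import re

def clean_text(text):
    """Apply the same three steps as preprocess_text to a plain string."""
    text = text.lower()
    text = re.sub(r'[^a-z?!.,]', ' ', text)   # keep only a-z and ? ! . ,
    text = re.sub(r'\s\s+', ' ', text)        # collapse runs of whitespace
    return text

sample = "Hello,  World! It's 2024."
cleaned = clean_text(sample)
print(repr(cleaned))       # 'hello, world! it s .'
print(len(cleaned) > 20)   # False, so the length filter would drop this line
```

Note that digits and apostrophes are replaced by spaces, so numbers disappear and contractions such as "it's" are split.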




**Step 4: Tokenisation**

In [5]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


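# Tokenize each example, truncating to the tokenizer's maximum length
def tokenize(sentence):
	return tokenizer(sentence['text'], truncation=True)


# Note: only the test split is tokenized here, and the training loop iterates over it
tokenized_inputs = datasets['test'].map(tokenize)
tokenized_inputs = tokenized_inputs.remove_columns(['text'])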


# DataCollator
batch = 16
data_collator = DataCollatorWithPadding(
	tokenizer=tokenizer, padding=True, return_tensors="pt")
dataloader = DataLoader(
	tokenized_inputs, batch_size=batch, collate_fn=data_collator)
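
The collator pads every batch to the length of its longest sequence and returns a matching attention mask. The sketch below mimics that behaviour in plain Python, assuming a padding id of 0 (as in BERT-style vocabularies) and using made-up token ids; `pad_batch` is a hypothetical helper, not the library's implementation.

```python
def pad_batch(sequences, pad_id=0):
    """Pad variable-length token-id lists to the longest one in the batch."""
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        padding = longest - len(seq)
        input_ids.append(seq + [pad_id] * padding)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(seq) + [0] * padding)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Hypothetical token ids for two sentences of different length
toy_batch = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])
print(toy_batch["input_ids"])
print(toy_batch["attention_mask"])
```

Padding per batch rather than to a global maximum keeps short batches small, which matters when most lines are far shorter than the 512-token limit.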



**Step 5: Embedding and Positional Encoding**

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft
import numpy as np
import pandas as pd

class PositionalEncoding(torch.nn.Module):
	"""Fixed sinusoidal positional encoding that is added to the token embeddings."""

	def __init__(self, d_model, max_sequence_length):
		super().__init__()
		self.d_model = d_model
		self.max_sequence_length = max_sequence_length
		self.positional_encoding = self.create_positional_encoding().to(device)

	def create_positional_encoding(self):

		# Initialize positional encoding matrix
		positional_encoding = np.zeros((self.max_sequence_length, self.d_model))

		# Calculate positional encoding for each position and each pair of dimensions
		for pos in range(self.max_sequence_length):
			for i in range(0, self.d_model, 2):
				# i already steps over the even indices (i = 2k), so the exponent is i / d_model
				angle = pos / (10000 ** (i / self.d_model))

				# sin goes to the even index, cos to the following odd index
				positional_encoding[pos, i] = np.sin(angle)
				if i + 1 < self.d_model:
					positional_encoding[pos, i + 1] = np.cos(angle)

		# Convert numpy array to PyTorch tensor and return it
		return torch.from_numpy(positional_encoding).float()

	def forward(self, x):
		# Slice to the current sequence length; broadcasting adds the same encoding to every batch item
		return x.to(device) + self.positional_encoding[:x.size(1), :]

class PositionalEmbedding(nn.Module):
  def __init__(self, sequence_length, vocab_size, embed_dim):
    super(PositionalEmbedding, self).__init__()
    self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
    self.position_embeddings = PositionalEncoding(embed_dim,sequence_length)

  def forward(self, inputs):
    embedded_tokens = self.token_embeddings(inputs).to(device)
    embedded_positions = self.position_embeddings(embedded_tokens).to(device)
    return embedded_positions.to(device)
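
For reference, the standard sinusoidal formulation is PE(pos, 2k) = sin(pos / 10000^(2k/d_model)) and PE(pos, 2k+1) = cos(pos / 10000^(2k/d_model)). The sketch below computes a tiny table in plain Python, with `i = 2k` as the even column index so that the exponent becomes `i / d_model`; the function name and the table size are illustrative.

```python
import math

def sinusoidal_encoding(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal position encodings."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            table[pos][i] = math.sin(angle)          # even column
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(angle)  # following odd column
    return table

pe = sinusoidal_encoding(max_len=4, d_model=4)
for row in pe:
    print([round(v, 4) for v in row])
```

Position 0 always encodes to alternating 0 and 1, and later columns oscillate more slowly than earlier ones, which gives every position a distinct pattern.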


**Step 6: Create FNet Encoder**

In [7]:
class FNetEncoder(nn.Module):

  def __init__(self,embed_dim, dense_dim):
    super(FNetEncoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.dense_proj = nn.Sequential(nn.Linear(self.embed_dim,self.dense_dim), nn.ReLU(), nn.Linear(self.dense_dim,self.embed_dim))

    self.layernorm_1 = nn.LayerNorm(self.embed_dim)
    self.layernorm_2 = nn.LayerNorm(self.embed_dim)

  def forward(self,inputs):

    fft_result = fft.fft2(inputs)

    #taking real part
    fft_real = fft_result.real.float()

    proj_input = self.layernorm_1 (inputs + fft_real)
    proj_output = self.dense_proj(proj_input)
    return self.layernorm_2(proj_input +proj_output)
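
The mixing step in the encoder is the real part of a 2D discrete Fourier transform over the sequence and hidden dimensions. The sketch below spells that out with a naive DFT on a 2x2 matrix; it is meant only to show what the operation computes, and `dft` and `fourier_mix` are illustrative helpers rather than anything from the notebook or from PyTorch.

```python
import cmath

def dft(vector):
    """Naive discrete Fourier transform of a list of numbers."""
    n = len(vector)
    return [
        sum(vector[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]

def fourier_mix(matrix):
    """Real part of a 2D DFT over a (sequence, hidden) matrix."""
    # Transform along the hidden dimension (each row) ...
    rows = [dft(row) for row in matrix]
    # ... then along the sequence dimension (each column)
    cols = [dft([rows[r][c] for r in range(len(rows))]) for c in range(len(rows[0]))]
    # cols[c][r] holds the value at (r, c); keep only the real part
    return [[cols[c][r].real for c in range(len(cols))] for r in range(len(rows))]

mixed = fourier_mix([[1.0, 2.0], [3.0, 4.0]])
print([[round(v, 6) for v in row] for row in mixed])
```

Every output entry depends on every input entry, so tokens are mixed without any learned attention weights.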


**Step 7: Create FNet Decoder**

In [8]:
class FNetDecoder(nn.Module):

  def __init__(self,embed_dim,dense_dim,num_heads):
    super(FNetDecoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.num_heads = num_heads

    self.attention_1 = nn.MultiheadAttention(embed_dim,num_heads,batch_first=True)
    self.attention_2 = nn.MultiheadAttention(embed_dim,num_heads,batch_first=True)

    self.dense_proj = nn.Sequential(nn.Linear(embed_dim, dense_dim),nn.ReLU(),nn.Linear(dense_dim, embed_dim))

    self.layernorm_1 = nn.LayerNorm(embed_dim)
    self.layernorm_2 = nn.LayerNorm(embed_dim)
    self.layernorm_3 = nn.LayerNorm(embed_dim)

  def forward(self, inputs, encoder_outputs, mask=None):
    # Causal mask so each position attends only to itself and earlier positions
    causal_mask = nn.Transformer.generate_square_subsequent_mask(inputs.size(1)).to(inputs.device)

    attention_output_1, _ = self.attention_1(inputs, inputs, inputs, attn_mask=causal_mask)
    out_1 = self.layernorm_1(inputs + attention_output_1)

    # `mask` follows the tokenizer convention (True = real token, shape (batch, source_length));
    # key_padding_mask expects True at the positions to ignore, so invert it.
    key_padding_mask = None if mask is None else ~mask.bool()
    attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs, key_padding_mask=key_padding_mask)
    out_2 = self.layernorm_2(out_1 + attention_output_2)

    # The feed-forward block and final normalisation run whether or not a mask was given
    proj_output = self.dense_proj(out_2)
    return self.layernorm_3(out_2 + proj_output)
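
The causal mask used by the first attention layer is an additive matrix: 0.0 where attention is allowed and negative infinity where it is blocked, so blocked positions receive zero weight after the softmax. A plain-Python sketch of the same pattern follows; `causal_mask` is an illustrative helper, not the PyTorch function.

```python
def causal_mask(size):
    """Additive attention mask: 0.0 where attention is allowed, -inf where it is blocked."""
    blocked = float("-inf")
    # Row = query position, column = key position; a query may look at itself and earlier keys
    return [[0.0 if col <= row else blocked for col in range(size)] for row in range(size)]

toy_mask = causal_mask(3)
for row in toy_mask:
    print(row)
```

The first row allows only position 0 and the last row allows every position, which is what prevents the decoder from looking at future tokens.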


**Step 8: FNet Model**

In [9]:
class FNetModel(nn.Module):
	def __init__(self, max_length, vocab_size, embed_dim, latent_dim, num_heads):
		super(FNetModel, self).__init__()

		self.encoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
		self.encoder1 = FNetEncoder(embed_dim, latent_dim)
		self.encoder2 = FNetEncoder(embed_dim, latent_dim)
		self.encoder3 = FNetEncoder(embed_dim, latent_dim)
		self.encoder4 = FNetEncoder(embed_dim, latent_dim)


		self.decoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
		self.decoder1 = FNetDecoder(embed_dim, latent_dim, num_heads)
		self.decoder2 = FNetDecoder(embed_dim, latent_dim, num_heads)
		self.decoder3 = FNetDecoder(embed_dim, latent_dim, num_heads)
		self.decoder4 = FNetDecoder(embed_dim, latent_dim, num_heads)


		self.dropout = nn.Dropout(0.5)
		self.dense = nn.Linear(embed_dim, vocab_size)

	def encoder(self,encoder_inputs):
		x_encoder = self.encoder_inputs(encoder_inputs)
		x_encoder = self.encoder1(x_encoder)
		x_encoder = self.encoder2(x_encoder)
		x_encoder = self.encoder3(x_encoder)
		x_encoder = self.encoder4(x_encoder)
		return x_encoder

	def decoder(self,decoder_inputs,encoder_output,att_mask):
		x_decoder = self.decoder_inputs(decoder_inputs)
		x_decoder = self.decoder1(x_decoder, encoder_output,att_mask) ## HERE for inference
		x_decoder = self.decoder2(x_decoder, encoder_output,att_mask) ## HERE for inference
		x_decoder = self.decoder3(x_decoder, encoder_output,att_mask) ## HERE for inference
		x_decoder = self.decoder4(x_decoder, encoder_output,att_mask) ## HERE for inference
		decoder_outputs = self.dense(x_decoder)

		return decoder_outputs

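	def forward(self, encoder_inputs, decoder_inputs,att_mask = None):
		encoder_output = self.encoder(encoder_inputs)
		# NOTE: att_mask is accepted but not forwarded, so the decoder always
		# runs without a padding mask over the encoder outputs.
		decoder_output = self.decoder(decoder_inputs,encoder_output,att_mask=None)
		return decoder_output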


**Step 9: Initialize Model**

We declare the hyperparameters and initialize the model.

In [10]:
# Hyperparameters
MAX_LENGTH = 512
VOCAB_SIZE = len(tokenizer.vocab)
EMBED_DIM = 256
LATENT_DIM = 100
NUM_HEADS = 4

# Create an instance of the model
fnet_model = FNetModel(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM, LATENT_DIM, NUM_HEADS).to(device)
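
To get a feel for the size of this configuration, the trainable parameters can be counted by hand. The sketch below assumes a vocabulary of 30522 tokens (the BERT uncased vocabulary used by this tokenizer) and default bias terms in every layer; the helper names are illustrative.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def layernorm_params(dim):
    """Scale and shift vectors of a layer normalisation."""
    return 2 * dim

def attention_params(dim):
    # Multi-head attention: packed q/k/v projection plus the output projection
    return linear_params(dim, 3 * dim) + linear_params(dim, dim)

vocab, embed, latent = 30522, 256, 100  # vocab size is an assumption, see above

feed_forward = linear_params(embed, latent) + linear_params(latent, embed)
encoder_block = feed_forward + 2 * layernorm_params(embed)
decoder_block = 2 * attention_params(embed) + feed_forward + 3 * layernorm_params(embed)
embeddings = 2 * vocab * embed            # encoder and decoder token embeddings
output_head = linear_params(embed, vocab)
total = embeddings + 4 * encoder_block + 4 * decoder_block + output_head

print(f"encoder block: {encoder_block:,}")
print(f"decoder block: {decoder_block:,}")
print(f"total:         {total:,}")
```

An FNet encoder block is roughly a tenth the size of a decoder block here because the Fourier mixing has no parameters, and most of the model sits in the two embedding tables and the output layer.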


**Step 10: Train the Model**

We declare the optimizer and the loss function. An Adam optimizer updates the model's parameters during training, and CrossEntropyLoss serves as the loss function. We then train the model for 20 epochs.

In [12]:
import torch
import torch.nn as nn

# Define your optimizer and loss function
optimizer = torch.optim.Adam(fnet_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=-100)  # Set ignore_index to -100

epochs = 20
for epoch in range(epochs):
    train_loss = 0
    for batch in dataloader:
        encoder_inputs_tensor = batch['input_ids'][:, :-1].to(device)
        decoder_inputs_tensor = batch['input_ids'][:, 1:].to(device)

        att_mask = batch['attention_mask'][:, :-1].to(device).to(dtype=torch.bool)
        optimizer.zero_grad()

        # Forward pass
        outputs = fnet_model(encoder_inputs_tensor, decoder_inputs_tensor, att_mask)

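        # Build the targets and set padded positions to -100 so the loss ignores them.
        # NOTE: the targets are the same tokens the decoder receives as input, so the model
        # can learn to copy them; this is consistent with the loss collapsing towards zero.
        targets = decoder_inputs_tensor.clone()
        targets[batch['attention_mask'][:, 1:] == 0] = -100  # attention_mask == 0 marks padding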

        # Calculate loss
        loss = criterion(outputs.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        train_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    print(f"Epoch: {epoch}, Train Loss: {train_loss}")


Epoch: 0, Train Loss: 9.368071470351424
Epoch: 1, Train Loss: 0.47274261838174425
Epoch: 2, Train Loss: 0.08393846603576094
Epoch: 3, Train Loss: 0.05510656790283974
Epoch: 4, Train Loss: 0.04125464403841761
Epoch: 5, Train Loss: 0.032524852103961166
Epoch: 6, Train Loss: 0.02627902630956669
Epoch: 7, Train Loss: 0.021591453416476725
Epoch: 8, Train Loss: 0.01803341192226071
Epoch: 9, Train Loss: 0.015260288758327079
Epoch: 10, Train Loss: 0.013070145338133443
Epoch: 11, Train Loss: 0.011313059420899663
Epoch: 12, Train Loss: 0.009869489724678715
Epoch: 13, Train Loss: 0.008661623970965593
Epoch: 14, Train Loss: 0.007641943623184488
Epoch: 15, Train Loss: 0.006776599097975122
Epoch: 16, Train Loss: 0.006037613589569446
Epoch: 17, Train Loss: 0.00540227940382465
Epoch: 18, Train Loss: 0.004852015348205896
Epoch: 19, Train Loss: 0.004372161977016731
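
The losses logged above fall to nearly zero within a few epochs. One likely contributor is that the decoder input and the targets in the loop are the same slice of `input_ids`, so predicting the target amounts to copying the input. A common alternative is shifted teacher forcing, sketched below on plain lists with made-up token ids; on the batched tensors the equivalent slices would be `[:, :-1]` for the decoder input and `[:, 1:]` for the targets. `shift_for_teacher_forcing` is a hypothetical helper, and because the encoder in this notebook also sees the full sequence, the shift alone would probably not remove all leakage.

```python
IGNORE_INDEX = -100  # same value the training loop passes to CrossEntropyLoss

def shift_for_teacher_forcing(input_ids, attention_mask):
    """Split one padded sequence into decoder inputs and next-token targets."""
    decoder_inputs = input_ids[:-1]          # tokens 0 .. n-2 are fed to the decoder
    targets = list(input_ids[1:])            # tokens 1 .. n-1 are what it should predict
    for pos, keep in enumerate(attention_mask[1:]):
        if keep == 0:
            targets[pos] = IGNORE_INDEX      # padded positions do not contribute to the loss
    return decoder_inputs, targets

# Hypothetical ids: [CLS] w1 w2 [SEP] [PAD]
ids = [101, 2129, 2024, 102, 0]
mask = [1, 1, 1, 1, 0]
dec_in, tgt = shift_for_teacher_forcing(ids, mask)
print(dec_in)
print(tgt)
```

With this split the decoder at position t only sees tokens up to t and has to predict token t + 1.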


**Step 11: Use the Model for Text Generation**

In [23]:
import re

def preprocess_text(sentence):
    # Lowering the sentence and storing in the text variable
    text = sentence.lower()

    # Removing characters other than alphabets and punctuation
    text = re.sub('[^a-z?!.,]', ' ', text)

    # Optionally, you can remove extra spaces
    text = re.sub(' +', ' ', text).strip()

    return text

def decode_sentence(input_sentence, fnet_model):
    fnet_model.eval()  # Set the model to evaluation mode

    with torch.no_grad():  # Disable gradient calculation for inference
        # Tokenize the input sentence after preprocessing
        processed_text = preprocess_text(input_sentence)
        tokenized_input_sentence = torch.tensor(tokenizer(processed_text)['input_ids']).to(device)

        # Start the target sequence with the tokenizer's [CLS] token
        tokenized_target_sentence = torch.tensor([tokenizer.cls_token_id]).to(device)

        current_text = processed_text  # Start with the preprocessed input text

        for i in range(MAX_LENGTH):
            # Predict the next token greedily (argmax over the vocabulary at the last position)
            predictions = fnet_model(tokenized_input_sentence[:-1].unsqueeze(0), tokenized_target_sentence.unsqueeze(0))
            predicted_index = torch.argmax(predictions[0, -1, :]).item()

            if predicted_index == tokenizer.sep_token_id:  # Stop if the [SEP] end token is generated
                break

            # Decode the token id and append it to the running text
            current_text += " " + tokenizer.decode([predicted_index])

            # Extend the target sequence and re-tokenize the growing input text,
            # truncating so it never exceeds the positional-encoding length
            tokenized_target_sentence = torch.cat([tokenized_target_sentence, torch.tensor([predicted_index]).to(device)], 0)
            tokenized_input_sentence = torch.tensor(tokenizer(current_text, truncation=True, max_length=MAX_LENGTH)['input_ids']).to(device)

        return current_text

# Example usage
output_sentence = decode_sentence('How are you ?', fnet_model)
print(output_sentence)


how are you ? tournament tournament tournament tournament tournament tournament [... "tournament" repeats until the output is cut off mid-word]
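
The repeated token above is a typical failure mode of greedy decoding: if the model prefers the same token at every step and never emits the end token, the loop only stops when it hits its iteration limit. The sketch below shows the greedy loop in isolation with an explicit cap on new tokens; `greedy_decode`, `toy_scores` and the token ids are hypothetical stand-ins for the model and tokenizer.

```python
def greedy_decode(next_token_fn, start_id, end_id, max_new_tokens=20):
    """Greedy decoding: repeatedly append the highest-scoring next token."""
    generated = [start_id]
    for _ in range(max_new_tokens):
        scores = next_token_fn(generated)                  # one score per vocabulary id
        best = max(range(len(scores)), key=scores.__getitem__)
        if best == end_id:                                 # stop at the end token
            break
        generated.append(best)
    return generated[1:]                                   # drop the start token

def toy_scores(prefix):
    # Stand-in for the model: prefers id 2 twice, then the end token (id 1)
    return [0.0, 1.0, 0.5] if len(prefix) >= 3 else [0.0, 0.2, 0.9]

tokens = greedy_decode(toy_scores, start_id=0, end_id=1)
print(tokens)

# A scorer that never prefers the end token is only stopped by the cap
stuck = greedy_decode(lambda prefix: [0.0, 0.1, 0.9], start_id=0, end_id=1, max_new_tokens=5)
print(stuck)
```

A small cap keeps a degenerate model from producing hundreds of repeated tokens, while sampling or a repetition penalty are common ways to reduce the looping itself.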