- using peft to fine-tune generative models
- using the Hugging Face 'polite' dataset
- using PyTorch Lightning for comfortable coding

In [1]:
class Config:
    """Single place for the notebook's hyper-parameters and resource identifiers.

    Tokenizer keyword arguments, batch size and epoch count live on the class;
    model presets, generation length and dataset details are set per instance.
    """

    # Keyword arguments passed to every tokenizer call in the notebook.
    tok_params = dict(return_tensors='pt', padding='max_length', truncation=True)
    batch_size = 16
    num_epoch = 1

    def __init__(self):
        # --- Discriminator ---
        self.dmodel_preset = 'distilbert/distilbert-base-uncased'
        # Width of the discriminator's output layer; adjust to the task.
        self.dnum_labels = 10

        # --- Generator ---
        self.gmodel_preset = 'describeai/gemini-small'
        self.generator_max_length = 256

        # --- Dataset ---
        self.dataset = 'jdustinwind/Polite'
        # Column names holding the rude source text and its polite target.
        self.rude = 'src'
        self.polite = 'tgt'


cfg = Config()

In [2]:
# ! python -m pip install -q lightning
! pip install -q datasets==2.18.0

import datasets as ds  
import transformers as tr

ds.__version__, tr.__version__

('2.18.0', '4.39.3')

In [3]:
# COMING SOON
# ..........
# # comfortable model training and coding 
# import lightning as L 
# ..........

# neural network 
import torch
from torch import nn 
import torch.optim as optim

from sklearn.model_selection import train_test_split

# work with data
from torch.utils.data import Dataset, DataLoader

# load and use models 
from transformers import AutoModelForSeq2SeqLM, AutoModel, AutoTokenizer
import datasets 

from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def tokenize_batch(rude, polite, gen_tokenizer, dics_tokenizer): 
    """Tokenize a rude text for the generator and a polite text for the discriminator.

    Args:
        rude: text to be encoded with ``gen_tokenizer``.
        polite: text to be encoded with ``dics_tokenizer``.
        gen_tokenizer: tokenizer belonging to the generator model.
        dics_tokenizer: tokenizer belonging to the discriminator model
            (the misspelled name is kept so keyword callers keep working).

    Returns:
        dict with the generator-side ``input_ids`` / ``attention_mask`` of the
        rude text and the discriminator-side ``input_ids`` / ``attention_mask``
        of the polite text, each with the leading batch axis of size one removed.
    """
    params = cfg.tok_params
    
    gen_tokens_rude = gen_tokenizer(rude, **params)
    # Use the tokenizer that was passed in. The previous version silently read
    # a module-level `disc_tokenizer` and ignored this argument, which raises
    # NameError whenever that global has not been defined yet.
    disc_tokens_polite = dics_tokenizer(polite, **params)
    
    # squeeze(0) removes only the batch axis added by return_tensors='pt',
    # never a sequence axis that happens to have length one.
    return {
        "rude_input_ids_g": gen_tokens_rude['input_ids'].squeeze(0), 
        "rude_attention_mask_g": gen_tokens_rude['attention_mask'].squeeze(0), 
        
        "d_polite_input_ids": disc_tokens_polite['input_ids'].squeeze(0), 
        "d_polite_attention_mask": disc_tokens_polite['attention_mask'].squeeze(0), 
    }


In [5]:
class Generator(nn.Module):
    """Wrapper around a pretrained seq2seq model whose forward pass generates.

    The wrapped model is loaded with ``AutoModelForSeq2SeqLM.from_pretrained``
    and stored as ``self.gen``.
    """

    def __init__(self, text2text_preset):
        super().__init__()

        # Pretrained text-to-text model identified by the preset name.
        self.gen = AutoModelForSeq2SeqLM.from_pretrained(text2text_preset)

    def forward(self, **input_batch):
        """Pass the keyword inputs to ``generate`` and return its output.

        Generation length is capped at ``cfg.generator_max_length``.
        """
        length_cap = cfg.generator_max_length
        generated = self.gen.generate(**input_batch, max_length=length_cap)
        return generated
    
class Discriminator(nn.Module):
    """Pretrained encoder with a dropout + linear head that scores text.

    The head emits ``cfg.dnum_labels`` independent probabilities in [0, 1].
    The GAN in this notebook scores these outputs with ``BCELoss`` against
    all-ones (real) and all-zeros (fake) targets. A softmax head cannot meet
    either target because its outputs must sum to one, and both losses are
    then minimised by the same uniform output, so nothing discriminative is
    learned. A per-unit sigmoid removes that coupling.

    Args:
        model_preset: identifier passed to ``AutoModel.from_pretrained``.
        dropout: dropout probability applied before the linear head.
    """

    def __init__(self, model_preset, dropout = 0.3): 
        super().__init__()
        
        self.backbone = AutoModel.from_pretrained(model_preset)

        # Prefer the `hidden_size` config field and fall back to `dim`
        # (the field the original code read) when it is missing, so that
        # backbones other than DistilBERT can be plugged in.
        backbone_cfg = self.backbone.config
        hidden_dim = getattr(backbone_cfg, 'hidden_size', None)
        if hidden_dim is None:
            hidden_dim = backbone_cfg.dim

        self.dropout = nn.Dropout(dropout)
        self.lin = nn.Linear(hidden_dim, cfg.dnum_labels)
        # Element-wise sigmoid: every output unit is its own real/fake probability.
        self.activation = nn.Sigmoid()
        
    def forward(self, **input_batch):
        """Return per-unit probabilities for each sequence in the batch.

        ``input_batch`` is forwarded unchanged to the backbone; only the hidden
        state at the first token position is fed to the head.
        """
        first_token_state = self.backbone(**input_batch).last_hidden_state[:, 0, :]
        
        logits = self.lin(self.dropout(first_token_state))
        return self.activation(logits)

In [6]:
class GAN(): 
    """Adversarial training wrapper around a text generator and a text discriminator.

    Holds both networks (moved to ``device``), their tokenizers, one AdamW
    optimizer per network and a shared ``BCELoss`` criterion.

    Args:
        discriminator: module returning probabilities in [0, 1] for tokenized text.
        discriminator_tokenizer: tokenizer matching ``discriminator``.
        generator: module returning generated token ids for tokenized text.
        generator_tokenizer: tokenizer matching ``generator``.
        discriminator_lr: learning rate of the discriminator optimizer.
        generator_lr: learning rate of the generator optimizer.
    """

    def __init__(self, 
#                  DISCRIMINATOR
                 discriminator, 
                 discriminator_tokenizer, 
                 
                 
#                  GENERATOR
                 generator, 
                 generator_tokenizer, 
                 
                 discriminator_lr = 1e-4, 
                 generator_lr = 1e-4, 
                ): 
        
        self.d = discriminator.to(device)
        self.g = generator.to(device)
        
        self.dt = discriminator_tokenizer
        self.gt = generator_tokenizer
        
        self.do = optim.AdamW(discriminator.parameters(), lr=discriminator_lr)
        self.go = optim.AdamW(generator.parameters(), lr=generator_lr)
        
        self.criterion = nn.BCELoss()
        
    def train_step(self, batch): 
        """Run one generator update followed by one discriminator update.

        Args:
            batch: dict of tensors with keys ``rude_input_ids_g``,
                ``rude_attention_mask_g``, ``d_polite_input_ids`` and
                ``d_polite_attention_mask``.

        Returns:
            dict with ``generator_loss`` and ``discriminator_loss`` as detached
            scalar tensors, so callers can store them without keeping the
            autograd graph alive.
        """
        batch = {k: v.to(device) for (k, v) in batch.items()}
        
        # -----------------
        #  Train Generator
        # -----------------
        
        self.go.zero_grad()
        
        polite_ids = self.g(input_ids = batch['rude_input_ids_g'], attention_mask = batch['rude_attention_mask_g'])

        # Strip special tokens (padding, end-of-sequence, ...) while decoding.
        # Left in, they reach the discriminator as literal text and make
        # generated samples trivially recognisable.
        polite_gen = self.gt.batch_decode(polite_ids, skip_special_tokens=True)
        polite_gen = self.dt(polite_gen, **cfg.tok_params).to(device)
        
        dics_polite_gen = self.d(**polite_gen)
        
        # The generator wants its output to be classified as real (1).
        g_loss = self.criterion(dics_polite_gen, torch.ones_like(dics_polite_gen))
        # NOTE(review): `polite_gen` is built by decoding token ids to strings
        # and re-tokenizing them, so there is no differentiable path from
        # `g_loss` back to the generator's parameters. This backward pass only
        # fills discriminator gradients (cleared below) and the generator
        # optimizer has nothing to apply. Training the generator needs a
        # different objective (e.g. a policy-gradient reward) — TODO.
        g_loss.backward()
        self.go.step()
        
        
        # -----------------
        #  Train Discriminator
        # -----------------
        
        self.do.zero_grad() 
        
        dics_polite_fake = self.d(**polite_gen)
        dics_polite_real = self.d(input_ids = batch['d_polite_input_ids'], attention_mask = batch['d_polite_attention_mask'])
        
        # Generated text should be scored 0, reference polite text 1.
        fake_loss = self.criterion(dics_polite_fake, torch.zeros_like(dics_polite_fake)) 
        real_loss = self.criterion(dics_polite_real, torch.ones_like(dics_polite_real))
        
        # One backward pass over the summed loss yields the same gradients as
        # two separate passes.
        d_loss = fake_loss + real_loss
        d_loss.backward()
        self.do.step() 
        
        return {'generator_loss': g_loss.detach(), 'discriminator_loss': d_loss.detach()}

In [7]:
# Generator side: seq2seq network and the tokenizer of the same preset.
generator = Generator(text2text_preset=cfg.gmodel_preset)
gen_tokenizer = AutoTokenizer.from_pretrained(cfg.gmodel_preset)

# Discriminator side: encoder-based scorer and the tokenizer of the same preset.
discriminator = Discriminator(model_preset=cfg.dmodel_preset)
disc_tokenizer = AutoTokenizer.from_pretrained(cfg.dmodel_preset)

config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


tokenizer_config.json:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [8]:
# Wire both networks and their tokenizers into the training wrapper;
# learning rates are left at the GAN defaults.
gan = GAN(
    discriminator=discriminator,
    discriminator_tokenizer=disc_tokenizer,
    generator=generator,
    generator_tokenizer=gen_tokenizer,
)

In [9]:
# Download the corpus and keep its 'train' split as a pandas DataFrame.
ds = datasets.load_dataset(cfg.dataset)
ds = ds['train'].to_pandas() 

# Hold out 10% of the rows. The fixed seed makes the partition identical on
# every Restart & Run All, so results stay comparable between runs.
ds_train, ds_test = train_test_split(ds, test_size = 0.1, random_state = 42)

Downloading readme:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.
Downloading data: 100%|██████████| 12.6M/12.6M [00:01<00:00, 6.91MB/s]


Generating train split: 0 examples [00:00, ? examples/s]

In [10]:
class MyDataset(Dataset):
    """Index-aligned rude/polite text pairs, tokenized on access.

    ``polites`` and ``rudes`` are indexable collections of texts; item ``i``
    pairs ``rudes[i]`` with ``polites[i]``. The dataset length is the number
    of polite texts.
    """

    def __init__(self, polites, rudes, generator_tokenizer, discriminator_tokenizer):
        super().__init__()

        # Raw texts.
        self.polites, self.rudes = polites, rudes
        # Tokenizers handed to `tokenize_batch` on every lookup.
        self.gt, self.dt = generator_tokenizer, discriminator_tokenizer

    def __len__(self):
        return len(self.polites)

    def __getitem__(self, idx):
        """Return the `tokenize_batch` result for the pair stored at ``idx``."""
        rude_text = self.rudes[idx]
        polite_text = self.polites[idx]
        return tokenize_batch(rude_text, polite_text, self.gt, self.dt)
    
def build_tensor_ds(dataframe, shuffle = True): 
    """Wrap the text columns of ``dataframe`` in a batched ``DataLoader``.

    Args:
        dataframe: DataFrame containing the ``cfg.polite`` and ``cfg.rude`` columns.
        shuffle: whether the loader reshuffles the rows every epoch. Defaults
            to True, which is what training needs; pass False for evaluation
            data so batches arrive in a stable order.

    Returns:
        DataLoader over a ``MyDataset`` with batches of ``cfg.batch_size`` items.
    """
    dataset = MyDataset(polites = dataframe[cfg.polite].to_list(), 
                        rudes = dataframe[cfg.rude].to_list(), 
        
                        generator_tokenizer = gen_tokenizer, 
                        discriminator_tokenizer = disc_tokenizer)

    return DataLoader(dataset, batch_size = cfg.batch_size, shuffle = shuffle)
    
dl_train = build_tensor_ds(ds_train)
# Validation batches do not need a random order.
dl_valid = build_tensor_ds(ds_test, shuffle = False)

In [11]:
# Put both networks in training mode before optimisation starts.
generator.train()
discriminator.train() 

logs = list() 
for epoch in range(cfg.num_epoch): 
    print(f"{epoch}/{cfg.num_epoch}")
    
    for batch in tqdm(dl_train):
        log = gan.train_step(batch)
        # Record this step's losses. The previous `logs.append(logs)` appended
        # the list to itself and threw every step's losses away. Converting to
        # plain floats keeps the history small and free of device tensors.
        logs.append({name: float(value) for name, value in log.items()})

0/1


100%|██████████| 5625/5625 [4:03:36<00:00,  2.60s/it]


In [12]:
# Bundle both networks' weights with the training log and write one file.
checkpoint_payload = {
    'generator': generator.state_dict(),
    'discriminator': discriminator.state_dict(),
    'logs': logs,
}
torch.save(checkpoint_payload, f='checkpoint')