# Get preprocessed 'DontPatronizeMe' dataset

In [None]:
!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1-gNxTZfDL0aOpzOnxE80M29dUVjSoozn' -O 'train.csv'
!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1-cSiEWP_NbDu7fo_7s8O5P163oKLQcBh' -O 'valid.csv'
!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1-13l35-18IYPFSV_36llsJbb7c4Gu2o0' -O 'test.csv'
!pip install transformers


# Import package, model, dataset

In [26]:
import pandas as pd
import copy
from transformers import (
    DistilBertTokenizer, DistilBertForMaskedLM, DistilBertConfig,
    BertTokenizer, BertModel as Bert,
    activations
)
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
import tqdm
import matplotlib.pyplot as plt

if torch.cuda.is_available():
  dev = "cuda:0"
else:
  dev = "cpu"
device = torch.device(dev)
print("using device: ", dev)

using device:  cpu


In [5]:
# read pandas data
train_path = './train.csv'
valid_path = './valid.csv'
test_path = './test.csv'

train_df = pd.read_csv(train_path).dropna()
valid_df = pd.read_csv(valid_path).dropna()
test_df = pd.read_csv(test_path).dropna()

In [7]:
# download pretrained model and tokenizer
def save_model_tokenizer(tokenizer_class, model_class, name):
  tokenizer = tokenizer_class.from_pretrained(name)
  tokenizer.save_pretrained(f"./tokenizers/{name}-local")
  model = model_class.from_pretrained(name)
  model.save_pretrained(f"./models/{name}-local/")

save_model_tokenizer(DistilBertTokenizer, DistilBertForMaskedLM, "distilbert-base-uncased")
# save_model_tokenizer(BertTokenizer, Bert, "bert-base-cased")

# Hyperparameters

In [58]:
# hyperparameters
batch_size = 16
max_length = 128 # max text length
learning_rate = 5e-5
epoch_num = 2
linear_probe = False

# diffusion hyperparameter
beta_min = 0.0001
beta_max = 0.02
step_tot = 2000 # total noise adding steps
sample_size = 1 # number of sample steps in each diffuse sequence TODO: suspect can only be 1, otherwise mask is wrong
train_embedding = False # if embedding is trainable or random gaussian initialize, TODO: why diffusion lm can have term as lernable, cause collaps
x_0_prediction = True # if model predicts x_0 or x_{t-1}

# Model, trainer and loss function

In [90]:
class DistilBertModel(nn.Module):
  def __init__(self, embedding, projection, train_embedding=True, config=None) -> None:
    super().__init__()

    # self.model = DistilBertForMaskedLM.from_pretrained("./models/distilbert-base-uncased-local", local_files_only=True, config=config).to(device)
    self.model = DistilBertForMaskedLM(config).to(device)
    
    self.embedding = copy.deepcopy(embedding.requires_grad_(train_embedding))
    self.projection = copy.deepcopy(projection.requires_grad_(train_embedding))
    self.projection.bias.data = torch.zeros(self.projection.bias.data.shape, device=device)
    self.model.set_input_embeddings(nn.Sequential())
    self.model.set_output_embeddings(nn.Sequential())

    # print(self.model.config)

  def parameters(self):
    return self.model.parameters()
    # return list(model.model.parameters()) + list(model.embedding.parameters()) + list(model.projection.parameters())
  
  def forward(self, x, mask):
    '''
    return 
      feature_out, shape: [batch_size, seq_len, dim]
      vocab_out, shape: [batch_size, seq_len, vocab_size]
    '''
    
    x_out = self.model(x, mask)[0]
    return self.projection(x_out), x_out

class LinearModel(nn.Module):
  def __init__(self, embedding, projection, train_embedding=True, config=None) -> None:
    super().__init__()

    # self.model = DistilBertForMaskedLM.from_pretrained("./models/distilbert-base-uncased-local", local_files_only=True, config=config).to(device)
    self.model = nn.Linear(768, 768).to(device)
    
    self.embedding = copy.deepcopy(embedding.requires_grad_(train_embedding))
    self.projection = copy.deepcopy(projection.requires_grad_(train_embedding))
    self.projection.bias.data = torch.zeros(self.projection.bias.data.shape, device=device)

  def parameters(self):
    return self.model.parameters()

  def forward(self, x, mask):
    hidden = self.model(x)
    return self.projection(hidden), hidden
    
origin = DistilBertForMaskedLM.from_pretrained("./models/distilbert-base-uncased-local", local_files_only=True).to(device)

configuration = DistilBertConfig()
model = DistilBertModel(origin.get_input_embeddings(), origin.get_output_embeddings(), train_embedding=train_embedding, config=configuration)
# model = EncoderModel(train_embedding=train_embedding)
# model = BertModel(train_embedding=train_embedding)
# model = LinearModel(origin.get_input_embeddings(), origin.get_output_embeddings(), train_embedding=train_embedding)

if linear_probe:  
  # TODO: linear probation not supported
  NotImplementedError()
  # trainer = optim.Adam(model.projection.parameters(), lr=learning_rate)
else:
  # parameter only include model, no embedding layer
  # trainer = optim.Adam(model.parameters(), lr=learning_rate)
  trainer = optim.AdamW(model.parameters(), lr=learning_rate)


In [105]:
betas = torch.hstack([torch.zeros(1), torch.linspace(beta_min, beta_max, step_tot)]).to(device)
alphas = 1 - betas
alpha_cumprod = torch.cumprod(alphas[:-1], 0)
def diffuse_t(x, t):
  '''
  x_shape: [batch_size, seq_len, dim]
  t shape: [sample num]

  return shape [batch_size * sample_num, seq_len, dim]
  '''
  # TODO: change model to use different noise
  batch_size, seq_len, dim = x.shape
  sample_shape = (sample_size, *(1, ) * len(x.shape))

  noise = torch.normal(0, 1, x.shape).to(device)
  mean = torch.sqrt(alpha_cumprod[t].reshape(sample_shape)) * x 
  epsilon = noise * torch.sqrt(1 - alpha_cumprod[t]).reshape(sample_shape)
  return (mean + epsilon).reshape((sample_size * batch_size, seq_len, dim))

def generate_diffuse_pair(x_0, repeat_shape, t, t_next=-1):
  '''
  x_0 shape: [batch_size, seq_len, dim]
  t shape: [sample_num]
  repeat shape: (sample_num, 1, 1, ...)
  
  return (net input, net target)
    shape [batch_size * sample_num, seq_len, dim]
  '''
  if t_next == -1:
    # predict x_0
    return (diffuse_t(x_0, t), x_0.repeat(repeat_shape))

  # predict x_{t_next}
  return (diffuse_t(x_0, t), diffuse_t(x_0, t_next))

def loss(model, x_t, x_1, x_0, mask, idx, loss_func):
  ''' 
  input: 
    model,
    x_t, x_1, x_0 shape: [batch_size, seq_len, dim]
    mask
    seq shape: [batch_size, seq_len]TODO: seq only support one sample
    loss_func
  '''
  _, x_hat = model(x_t, mask)

  probability, x_0_hat = model(x_1, mask)

  idx = idx.unsqueeze(dim=-1)
  seq_probability_loss = -(nn.functional.softmax(probability, dim=-1)).gather(-1, idx).log().mean()
  
  return loss_func(x_hat, x_0), loss_func(x_0_hat, x_0), 0.1 * seq_probability_loss


# Define dataset

In [65]:
# define dataset 
class DPMDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, input_df):
        self.tokenizer = tokenizer
        self.texts = input_df['text'].tolist()

    def collate_fn(self, batch):
        # function for batch allocation
        texts = []

        for b in batch:
            texts.append(b)

        encodings = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)

        return {"input_ids": encodings["input_ids"].to(device), "attention_mask": encodings["attention_mask"].to(device)}

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

tokenizer = DistilBertTokenizer.from_pretrained("./tokenizers/distilbert-base-uncased-local/", local_files_only=True)
# tokenizer = BertTokenizer.from_pretrained("./tokenizers/bert-base-cased-local", local_files_only=True)

train_dataset = DPMDataset(tokenizer, train_df)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=train_dataset.collate_fn)


# Training

In [67]:
# training
# model = torch.load("model_continue1.pickle")["net"]
# trainer = optim.AdamW(model.parameters(), lr=learning_rate)
model.train()
print("start training")
for epoch in range(epoch_num):
  acc_loss = 0
  with tqdm.tqdm(train_loader, unit="batch") as tepoch: 
    for epoch, x in enumerate(tepoch):
  # for x in train_loader:
      x_0 = model.embedding(x["input_ids"])
      repeat_shape = (sample_size, *(1, ) * (len(x_0.shape) - 1))
      t = torch.randint(0, step_tot, repeat_shape, device=device)
      if x_0_prediction:
        x_input, x_tgt = generate_diffuse_pair(x_0, repeat_shape, t)
      else:
        print("not implemented")
        # x_input, x_tgt = generate_diffuse_pair(x_0, repeat_shape, t, torch.max(t - 30, 0))
      x_1_tgt = diffuse_t(x_0, torch.ones(1, dtype=torch.int64, device=device).repeat(repeat_shape))

      trainer.zero_grad()
      x_t_restore, x_1_restore, prob = loss(model, x_input, x_1_tgt, x_0, x["attention_mask"].repeat(repeat_shape), x["input_ids"], nn.L1Loss())
      l = x_t_restore + x_1_restore + prob
      l.backward()
      trainer.step()

      acc_loss += l

      tepoch.set_description(f"Epoch {epoch}")
      tepoch.set_postfix(
                        x_t_restore=x_t_restore.item(),
                         x_1_restore=x_1_restore.item(),
                         prob=prob.item(),
                         tot_loss=l.item())

  print(f"epoch {epoch} average loss: {acc_loss / len(train_loader)}, last loss x_t_restore, x_1_restore, prob: {x_t_restore, x_1_restore, prob}")


start training


Epoch 174:   3%|▎         | 175/6700 [00:09<05:57, 18.27batch/s, tot_loss=0.00151, x_1_restore=0.00151]


KeyboardInterrupt: 

In [92]:
# trial on inference
# model = torch.load("model_continue1.pickle")["net"]
# model.model.add_module("activation", activations.GELUActivation())
# model.model.weight.data = torch.eye(768).
# model.model.bias.data = torch.zeros(model.model.bias.data.shape)
model.eval()
model.to(device)
origin_text = train_df.loc[11]["text"]
text = tokenizer(origin_text, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(device)
print("origin text: ", origin_text)

x_0 = model.embedding(text["input_ids"])
repeat_shape = (sample_size, *(1, ) * (len(x_0.shape) - 1))
t = torch.randint(0, step_tot, repeat_shape, device=device)

noised_text = diffuse_t(x_0, t)
# x_1 = diffuse_t(x_0, torch.ones(1, dtype=torch.int64, device=device).repeat(repeat_shape))
x_t = diffuse_t(x_0, torch.tensor([25], dtype=torch.int64, device=device).repeat(repeat_shape))
print("noise added")
print("t", t)
print("x_0 ground truth: ", tokenizer.decode((x_0 @ model.projection.weight.data.T).argmax(dim=-1)[0]))

# multi-step inference
restored = x_t
for i in range(5):
  out, restored = model(restored, text["attention_mask"].repeat(repeat_shape)) 

  # print("inferred: ", tokenizer.decode(torch.softmax(out, dim=-1).argmax(dim=-1)[0]))
  print("inferred: ", tokenizer.decode(out.argmax(dim=-1)[0]))
  # print("loss", loss(model, noised_text, x_1, x_0.repeat(repeat_shape), text["attention_mask"].repeat(repeat_shape), text["input_ids"], nn.L1Loss()))

print("text t effectiveness")
# effectiveness of model on large t
for i in range(5, 500, 25):
  x_t = diffuse_t(x_0, torch.tensor([i], dtype=torch.int64, device=device).repeat(repeat_shape))
  out, _ = model(x_t, text["attention_mask"].repeat(repeat_shape)) 

  # print("inferred: ", tokenizer.decode(torch.softmax(out, dim=-1).argmax(dim=-1)[0]))
  print("t: ", i, "restore: ", tokenizer.decode(out.argmax(dim=-1)[0]))




origin text:  Critics have even taken to dobbing in Katrina Bungard to National Party leader Bill English when they see her sign-written car bearing her name and photo parked in disabled parks .
noise added
inferred:  [CLS] critics have even taken to dobbing in katrina bungard to national party leader bill english when they see her sign - written car bearing her name and photo parked in disabled parks. [SEP]
inferred:  [CLS] critics have even taken to dobbing in katrina bungard to national party leader bill english when they see her sign - written car bearing her name and photo parked in disabled parks. [SEP]
inferred:  [CLS] critics have even taken to dobbing in katrina bungard to national party leader bill english when they see her sign - written car bearing her name and photo parked in disabled parks. [SEP]
inferred:  [CLS] critics have even taken to dobbing in katrina bungard to national party leader bill english when they see her sign - written car bearing her name and photo parke

In [10]:
# save model
torch.save({"net": model.to(torch.device("cpu"))}, "model.pickle")