In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn.functional import cross_entropy
import numpy as np
import re, string
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from collections import Counter
import json
import kagglehub

# Download latest version
path = kagglehub.dataset_download("zynicide/wine-reviews" )
# Parameters
VOCAB_SIZE = 10000
MAX_LEN = 80
EMBEDDING_DIM = 256
N_HEADS = 2
FF_DIM = 256
BATCH_SIZE = 32
EPOCHS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load the data
with open(path+"/winemag-data-130k-v2.json") as json_data:
    wine_data = json.load(json_data)
filtered_data = [
    "wine review : "
    + x["country"]
    + " : "
    + x["province"]
    + " : "
    + x["variety"]
    + " : "
    + x["description"]
    for x in wine_data
    if x["country"] is not None
    and x["province"] is not None
    and x["variety"] is not None
    and x["description"] is not None
]
# Tokenizer
class SimpleTokenizer:
    def __init__(self, texts, vocab_size):
        self.vocab_size = vocab_size
        self.counter = Counter()
        for text in texts:
            self.counter.update(text.split())
        self.vocab = [word for word, _ in self.counter.most_common(vocab_size - 2)]
        self.word2idx = {word: idx + 2 for idx, word in enumerate(self.vocab)}
        self.word2idx["<pad>"] = 0
        self.word2idx["<unk>"] = 1

    def encode(self, text):
        return [self.word2idx.get(word, 1) for word in text.split()]

# Data preparation
def pad_punctuation(s):
    s = re.sub(f"([{string.punctuation}])", r" \1 ", s)
    s = re.sub(" +", " ", s)
    return s.lower()


# Create tokenizer
tokenizer = SimpleTokenizer(filtered_data, VOCAB_SIZE)


In [2]:
# Dataset class
class WineDataset(Dataset):
    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokens = self.tokenizer.encode(self.texts[idx])[:self.max_len+1]
        padding = [self.tokenizer.word2idx["<pad>"]] * (self.max_len + 1 - len(tokens))
        tokens += padding
        return torch.tensor(tokens[:-1]), torch.tensor(tokens[1:])

train_dataset = WineDataset([pad_punctuation(t) for t in filtered_data], tokenizer, MAX_LEN)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [3]:
class GPTBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        seq_len = x.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len)).to(x.device)
        mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)

        attn_output, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.ln1(x + attn_output)
        ffn_output = self.ffn(x)
        return self.ln2(x + ffn_output)


In [4]:
class GPTModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_len, num_heads, ff_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        self.transformer = GPTBlock(embed_dim, num_heads, ff_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.pos_embedding(positions)
        x = self.transformer(x)
        return self.fc(x)


In [5]:
# Training loop
# Training
model = GPTModel(VOCAB_SIZE, EMBEDDING_DIM, MAX_LEN, N_HEADS, FF_DIM).to(DEVICE)
optimizer = Adam(model.parameters(), lr=0.0001)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for x_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)
        optimizer.zero_grad()
        logits = model(x_batch)
        loss = cross_entropy(logits.view(-1, VOCAB_SIZE), y_batch.view(-1), ignore_index=tokenizer.word2idx['<pad>'])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Training Loss: {avg_loss:.4f}")



Epoch 1/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.48it/s]


Epoch 1, Training Loss: 4.0331


Epoch 2/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.39it/s]


Epoch 2, Training Loss: 3.4174


Epoch 3/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.53it/s]


Epoch 3, Training Loss: 3.2060


Epoch 4/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.35it/s]


Epoch 4, Training Loss: 3.0849


Epoch 5/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.04it/s]


Epoch 5, Training Loss: 3.0026


Epoch 6/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.00it/s]


Epoch 6, Training Loss: 2.9418


Epoch 7/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.67it/s]


Epoch 7, Training Loss: 2.8937


Epoch 8/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.04it/s]


Epoch 8, Training Loss: 2.8538


Epoch 9/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.33it/s]


Epoch 9, Training Loss: 2.8195


Epoch 10/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.50it/s]


Epoch 10, Training Loss: 2.7895


Epoch 11/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.13it/s]


Epoch 11, Training Loss: 2.7633


Epoch 12/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.89it/s]


Epoch 12, Training Loss: 2.7401


Epoch 13/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.53it/s]


Epoch 13, Training Loss: 2.7191


Epoch 14/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.09it/s]


Epoch 14, Training Loss: 2.6998


Epoch 15/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.55it/s]


Epoch 15, Training Loss: 2.6820


Epoch 16/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.41it/s]


Epoch 16, Training Loss: 2.6651


Epoch 17/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.91it/s]


Epoch 17, Training Loss: 2.6499


Epoch 18/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.39it/s]


Epoch 18, Training Loss: 2.6356


Epoch 19/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.33it/s]


Epoch 19, Training Loss: 2.6223


Epoch 20/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.43it/s]


Epoch 20, Training Loss: 2.6099


Epoch 21/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.12it/s]


Epoch 21, Training Loss: 2.5983


Epoch 22/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.16it/s]


Epoch 22, Training Loss: 2.5876


Epoch 23/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.32it/s]


Epoch 23, Training Loss: 2.5772


Epoch 24/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.22it/s]


Epoch 24, Training Loss: 2.5677


Epoch 25/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.97it/s]


Epoch 25, Training Loss: 2.5587


Epoch 26/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.36it/s]


Epoch 26, Training Loss: 2.5501


Epoch 27/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.68it/s]


Epoch 27, Training Loss: 2.5422


Epoch 28/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.19it/s]


Epoch 28, Training Loss: 2.5345


Epoch 29/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.04it/s]


Epoch 29, Training Loss: 2.5272


Epoch 30/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.52it/s]


Epoch 30, Training Loss: 2.5203


Epoch 31/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.59it/s]


Epoch 31, Training Loss: 2.5137


Epoch 32/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.35it/s]


Epoch 32, Training Loss: 2.5075


Epoch 33/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.40it/s]


Epoch 33, Training Loss: 2.5014


Epoch 34/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.50it/s]


Epoch 34, Training Loss: 2.4957


Epoch 35/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.13it/s]


Epoch 35, Training Loss: 2.4903


Epoch 36/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.28it/s]


Epoch 36, Training Loss: 2.4851


Epoch 37/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.50it/s]


Epoch 37, Training Loss: 2.4799


Epoch 38/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.57it/s]


Epoch 38, Training Loss: 2.4751


Epoch 39/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.30it/s]


Epoch 39, Training Loss: 2.4704


Epoch 40/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.35it/s]


Epoch 40, Training Loss: 2.4658


Epoch 41/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.06it/s]


Epoch 41, Training Loss: 2.4614


Epoch 42/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.26it/s]


Epoch 42, Training Loss: 2.4573


Epoch 43/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.17it/s]


Epoch 43, Training Loss: 2.4532


Epoch 44/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.28it/s]


Epoch 44, Training Loss: 2.4493


Epoch 45/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.58it/s]


Epoch 45, Training Loss: 2.4454


Epoch 46/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.04it/s]


Epoch 46, Training Loss: 2.4419


Epoch 47/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.17it/s]


Epoch 47, Training Loss: 2.4383


Epoch 48/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.83it/s]


Epoch 48, Training Loss: 2.4347


Epoch 49/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.27it/s]


Epoch 49, Training Loss: 2.4313


Epoch 50/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.09it/s]


Epoch 50, Training Loss: 2.4280


Epoch 51/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.38it/s]


Epoch 51, Training Loss: 2.4249


Epoch 52/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.30it/s]


Epoch 52, Training Loss: 2.4218


Epoch 53/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.53it/s]


Epoch 53, Training Loss: 2.4186


Epoch 54/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.50it/s]


Epoch 54, Training Loss: 2.4159


Epoch 55/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.32it/s]


Epoch 55, Training Loss: 2.4129


Epoch 56/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.54it/s]


Epoch 56, Training Loss: 2.4101


Epoch 57/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.49it/s]


Epoch 57, Training Loss: 2.4074


Epoch 58/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.62it/s]


Epoch 58, Training Loss: 2.4047


Epoch 59/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.95it/s]


Epoch 59, Training Loss: 2.4020


Epoch 60/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.10it/s]


Epoch 60, Training Loss: 2.3997


Epoch 61/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.76it/s]


Epoch 61, Training Loss: 2.3971


Epoch 62/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.48it/s]


Epoch 62, Training Loss: 2.3947


Epoch 63/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.80it/s]


Epoch 63, Training Loss: 2.3922


Epoch 64/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.22it/s]


Epoch 64, Training Loss: 2.3899


Epoch 65/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.26it/s]


Epoch 65, Training Loss: 2.3877


Epoch 66/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.43it/s]


Epoch 66, Training Loss: 2.3854


Epoch 67/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.05it/s]


Epoch 67, Training Loss: 2.3833


Epoch 68/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.49it/s]


Epoch 68, Training Loss: 2.3810


Epoch 69/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.12it/s]


Epoch 69, Training Loss: 2.3789


Epoch 70/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.84it/s]


Epoch 70, Training Loss: 2.3769


Epoch 71/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.54it/s]


Epoch 71, Training Loss: 2.3749


Epoch 72/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.23it/s]


Epoch 72, Training Loss: 2.3729


Epoch 73/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.84it/s]


Epoch 73, Training Loss: 2.3709


Epoch 74/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.33it/s]


Epoch 74, Training Loss: 2.3691


Epoch 75/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.34it/s]


Epoch 75, Training Loss: 2.3671


Epoch 76/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.46it/s]


Epoch 76, Training Loss: 2.3654


Epoch 77/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.59it/s]


Epoch 77, Training Loss: 2.3636


Epoch 78/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.44it/s]


Epoch 78, Training Loss: 2.3616


Epoch 79/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.31it/s]


Epoch 79, Training Loss: 2.3599


Epoch 80/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.47it/s]


Epoch 80, Training Loss: 2.3583


Epoch 81/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.59it/s]


Epoch 81, Training Loss: 2.3565


Epoch 82/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.36it/s]


Epoch 82, Training Loss: 2.3547


Epoch 83/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.42it/s]


Epoch 83, Training Loss: 2.3531


Epoch 84/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.15it/s]


Epoch 84, Training Loss: 2.3516


Epoch 85/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.56it/s]


Epoch 85, Training Loss: 2.3500


Epoch 86/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.01it/s]


Epoch 86, Training Loss: 2.3485


Epoch 87/100: 100%|██████████| 4060/4060 [00:54<00:00, 74.99it/s]


Epoch 87, Training Loss: 2.3468


Epoch 88/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.44it/s]


Epoch 88, Training Loss: 2.3453


Epoch 89/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.28it/s]


Epoch 89, Training Loss: 2.3438


Epoch 90/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.38it/s]


Epoch 90, Training Loss: 2.3424


Epoch 91/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.49it/s]


Epoch 91, Training Loss: 2.3408


Epoch 92/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.38it/s]


Epoch 92, Training Loss: 2.3395


Epoch 93/100: 100%|██████████| 4060/4060 [00:54<00:00, 75.02it/s]


Epoch 93, Training Loss: 2.3379


Epoch 94/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.69it/s]


Epoch 94, Training Loss: 2.3365


Epoch 95/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.40it/s]


Epoch 95, Training Loss: 2.3351


Epoch 96/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.82it/s]


Epoch 96, Training Loss: 2.3338


Epoch 97/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.48it/s]


Epoch 97, Training Loss: 2.3325


Epoch 98/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.29it/s]


Epoch 98, Training Loss: 2.3311


Epoch 99/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.88it/s]


Epoch 99, Training Loss: 2.3297


Epoch 100/100: 100%|██████████| 4060/4060 [00:53<00:00, 75.64it/s]

Epoch 100, Training Loss: 2.3284





In [6]:
def generate_text(model, tokenizer, start_prompt, max_tokens=80, temperature=1.0):
    model.eval()
    tokens = tokenizer.encode(start_prompt)
    tokens = tokens[:MAX_LEN]
    generated = tokens.copy()

    with torch.no_grad():
        for _ in range(max_tokens):
            input_tensor = torch.tensor([generated[-MAX_LEN:]], device=DEVICE)
            logits = model(input_tensor)
            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            if next_token == tokenizer.word2idx['<pad>']:
                break

            generated.append(next_token)

    idx2word = {idx: word for word, idx in tokenizer.word2idx.items()}
    generated_text = ' '.join(idx2word.get(idx, '<unk>') for idx in generated)
    return generated_text

# Example usage:
prompt = "wine review : US"
generated_text = generate_text(model, tokenizer, prompt, max_tokens=50, temperature=0.8).strip()
print("Generated text:\n", generated_text)


Generated text:
 wine review : US : <unk> : <unk> : a <unk> : with a ripe , slight green apple character that gives the wine . the texture of the wine , perfumed , white fruits are lifted by the apricot fruits . the wine is bright , ready to drink . wait until 2017
