In [8]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load titles as string
titles = open('titles.txt', 'r', encoding='utf-8').read()
titles[0:999]

'CURRENT PEDIATRIC DRUGS+WORLD WAR II AND MAYO+IMPORTANCE OF BREAST BIOPSY INCISION IN FINAL OUTCOME OF BREAST RECONSTRUCTION+BIOCHEMICAL EVALUATION OF ADRENAL DYSFUNCTION: THE LABORATORY PERSPECTIVE+PANCREATIC PSEUDOCYST THAT COMPRESSED THE INFERIOR VENA CAVA AND RESULTED IN EDEMA OF THE LOWER EXTREMITIES+DERMATOLOGIC MANIFESTATIONS OF HUMAN IMMUNODEFICIENCY VIRUS INFECTION+SURGICAL TREATMENT OF POSTINFARCTION RUPTURE OF A PAPILLARY MUSCLE+THE EVOLUTION OF MANAGEMENT OF POSTOPERATIVE PAIN+MAYO FOUNDATION COURSES AND MEETINGS+COLONOSCOPY+COLONOSCOPY: DR. MACCARTY REPLIES+THE SPINAL CORD INJURED PATIENT: COMPREHENSIVE MANAGEMENT+MEDICAL AND SURGICAL DISEASES OF THE PANCREAS+DIARRHEAL DISEASES (CURRENT TOPICS IN GASTROENTEROLOGY SERIES)+THE FOOT IN DIABETES+FOR PATIENTS: YOU AND HIV: A DAY AT A TIME+CONTRIBUTION OF A MEASURE OF DISEASE COMPLEXITY (COMPLEX) TO PREDICTION OF OUTCOME AND CHARGES AMONG HOSPITALIZED PATIENTS+MEASUREMENT OF SMALL BOWEL AND COLONIC TRANSIT: INDICATIONS AND METH

In [2]:
# Assign characters to integers
chars = sorted(list(set(titles)))
char_to_int = dict((c, i) for i, c in enumerate(chars))

In [3]:
# Summarize data
n_chars = len(titles)
n_vocab = len(chars)
print(f"Total Characters: {n_chars} ")
print(f"Total Vocab: {n_vocab}")

Total Characters: 427981 
Total Vocab: 45


In [4]:
# Prepare dataset of input to output pairs encoded as integers
seq_length = 100
dataX, dataY = [], []
for i in range(0, n_chars - seq_length, 1):
    seq_in = titles[i:i + seq_length]
    seq_out = titles[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

Total Patterns:  427881


In [5]:
# Reshape X to be [samples, time steps, features]
X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1)
X = X / float(n_vocab) # works better as number between 0 and 1
y = torch.tensor(dataY)
print(X.shape, y.shape)

torch.Size([427881, 100, 1]) torch.Size([427881])


In [6]:
# Define LSTM model
class CharModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=256, num_layers=2, batch_first=True, dropout=0.2)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(256, n_vocab)
    def forward(self, x):
        x, _ = self.lstm(x)
        # take only the last output
        x = x[:, -1, :]
        # produce output
        x = self.linear(self.dropout(x))
        return x

In [9]:
# Define model parameters
n_epochs = 40
batch_size = 128
model = CharModel()

# Set up optimization, loss function, and batch
optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss(reduction="sum")
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)
for inputs, targets in loader:
    inputs, targets = inputs.to(device), targets.to(device)

# Initialize values for keeping track of best model
best_model = None
best_loss = np.inf 

In [11]:
# Create save checkpoint function
def save_checkpoint(epoch, model, optimizer, loss):
    checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
                }
    filename = f"model_epoch_{epoch}.pt"
    torch.save(checkpoint, filename)

In [18]:
# Create function to load latest checkpoint
def load_latest_model(model, optimizer=None):
    latest_model = max(saved_models, key=lambda x: int(x.split('_')[2].split('.')[0]))
    current_model = torch.load(latest_model)
    
    model.load_state_dict(current_model['model_state_dict'])
    optimizer.load_state_dict(current_model['optimizer_state_dict'])
        
    return current_model['epoch'], model, optimizer

In [25]:
saved_models = [f for f in os.listdir() if f.startswith('model') and f.endswith('.pt')]
if saved_models:
    epoch, model, optimizer = load_latest_model(model, optimizer)
else:
    epoch = 0
    
    
model = model.to(device)

# Training loop
for current_epoch in range(epoch + 1, n_epochs):
    # Training
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch)
            loss += loss_fn(y_pred, y_batch)
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch: %d: Cross-entropy: %.4f" % (current_epoch, loss))
        
    # Save checkpoint
    save_checkpoint(current_epoch, model, optimizer, loss)

Epoch: 29: Cross-entropy: 468531.7812
Epoch: 30: Cross-entropy: 463084.7188
Epoch: 31: Cross-entropy: 457788.6250
Epoch: 32: Cross-entropy: 459815.1250
Epoch: 33: Cross-entropy: 455446.7188
Epoch: 34: Cross-entropy: 455746.1562
Epoch: 35: Cross-entropy: 451877.4688
Epoch: 36: Cross-entropy: 452637.7188
Epoch: 37: Cross-entropy: 448450.6562
Epoch: 38: Cross-entropy: 453189.9375
Epoch: 39: Cross-entropy: 446414.4375


In [26]:
torch.save([best_model, char_to_int], "40model.pth")