In [None]:
!pip install torch
!pip install transformers

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer
import string

# the character based cnn required in the paper
class CharCNN(nn.Module):
    def __init__(self, num_chars, char_embed_size, num_filters, kernel_sizes):
        super(CharCNN, self).__init__()
        self.char_embedding = nn.Embedding(num_chars, char_embed_size)
        self.convs = nn.ModuleList([
            nn.Conv1d(char_embed_size, num_filters, kernel_size=k) # 1 dimensional convolution, not 2d like in 691.
            for k in kernel_sizes
        ])

    def forward(self, x):
        '''
        x's shape is (batch, word_len, char_embed_size)
        '''
        x = self.char_embedding(x)
        # print(x)
        x = x.transpose(1, 2)  # (batch, char_embed_size, word_len)
        x = [F.relu(conv(x)) for conv in self.convs]  # cnn
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)  # join measurements
        return x

# full model, using lstm as highway and cnn from previous function.
class CharAwareLM(nn.Module):
    def __init__(self, num_chars, char_embed_size, num_filters, kernel_sizes,
                 word_vocab_size, word_embed_size, hidden_size, num_layers, dropout):
        super(CharAwareLM, self).__init__()
        self.char_cnn = CharCNN(num_chars, char_embed_size, num_filters, kernel_sizes)
        self.word_embedding = nn.Embedding(word_vocab_size, word_embed_size)
        self.lstm = nn.LSTM(char_embed_size + word_embed_size, hidden_size, num_layers, dropout=dropout)
        self.decoder = nn.Linear(hidden_size, word_vocab_size)

    def forward(self, word_input, char_input, hidden):
        char_output = self.char_cnn(char_input)  # Char-CNN output
        print(char_output.shape)
        # If CharCNN output is 2D (batch_size, features), unsqueeze to 3D (batch_size, seq_len, features)
        if char_output.dim() == 2:
            char_output = char_output.unsqueeze(1)

        word_embedding = self.word_embedding(word_input)  # Word embedding output
        
        print(word_embedding.shape)
        print(char_output.shape)

        # Ensure char_output's seq_len dimension matches that of word_embedding
        if word_embedding.size(1) != char_output.size(1):
            # This assumes char_output has seq_len of 1 and needs to be repeated for each word in the sequence
            char_output = char_output.repeat(1, word_embedding.size(1), 1)

        print(word_embedding.shape)
        print(char_output.shape)

        combined_emb = torch.cat((word_embedding, char_output), 2)  # combine
        lstm_out, hidden = self.lstm(combined_emb, hidden)  # LSTM output
        logits = self.decoder(lstm_out)  # Decode to word space
        return logits, hidden

In [19]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
char_dict = {char: idx + 1 for idx, char in enumerate(string.ascii_letters + string.digits + string.punctuation)}
char_dict['<pad>'] = 0
char_dict['<unk>'] = -1

# Prepare model hyperparameters
num_chars = len(char_dict)
char_embed_size = 50
num_filters = 100
kernel_sizes = [3, 4, 5]
word_vocab_size = tokenizer.vocab_size
word_embed_size = 768  # BERT-base hidden size
hidden_size = 512
num_layers = 2
dropout = 0.1

def simple_tokenize(sentence):
    # Include space as a separate token
    return [token for token in sentence.split(' ')] + [' ']

def char_indices(word, char_dict):
    return [char_dict.get(c, char_dict['<unk>']) for c in word] # make unknowns default return

def pad_sequences(sequences, maxlen, padding='post'):
    # Pads sequences to the same length
    num_instances = len(sequences)
    x = torch.zeros((num_instances, maxlen), dtype=torch.long)
    for i, seq in enumerate(sequences):
        if len(seq) != 0:
            if padding == 'pre':
                x[i, -len(seq):] = torch.tensor(seq[:maxlen], dtype=torch.long)
            else:
                x[i, :len(seq)] = torch.tensor(seq[:maxlen], dtype=torch.long)
    return x

def sentence_to_char_toks(sentence):
    word_tokens = tokenizer.tokenize(sentence)
    word_ids = tokenizer.convert_tokens_to_ids(word_tokens)
    word_input = tokenizer(sentence, return_tensors="pt")['input_ids']
    char_inputs = [char_indices(word, char_dict) for word in word_tokens]
    char_input = pad_sequences(char_inputs, maxlen=max(len(word) for word in word_tokens), padding='post')
    return word_input, char_input

model = CharAwareLM(
    num_chars=num_chars, char_embed_size=char_embed_size,
    num_filters=num_filters, kernel_sizes=kernel_sizes,
    word_vocab_size=word_vocab_size, word_embed_size=word_embed_size,
    hidden_size=hidden_size, num_layers=num_layers, dropout=dropout)

sentence = "hello world"
word_input, char_input = sentence_to_char_toks(sentence)

hidden = None  # Should be initialized properly if using an LSTM/GRU

# Forward pass through the model
with torch.no_grad():
    logits, hidden = model(word_input, char_input, hidden)

print("Logits output from the model:", logits)

torch.Size([2, 300])
torch.Size([1, 4, 768])
torch.Size([2, 1, 300])
torch.Size([1, 4, 768])
torch.Size([2, 4, 300])


RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 2 for tensor number 1 in the list.

In [None]:
class CharAwareLM(nn.Module):
    def __init__(self, num_chars, char_embed_size, num_filters, kernel_sizes,
                 hidden_size, num_layers, dropout):
        super(CharAwareLM, self).__init__()
        self.char_cnn = CharCNN(num_chars, char_embed_size, num_filters, kernel_sizes)
        self.lstm = nn.LSTM(num_filters * len(kernel_sizes), hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.decoder = nn.Linear(hidden_size, num_chars)

    def forward(self, char_input, hidden=None):
        char_output = self.char_cnn(char_input)
        lstm_out, hidden = self.lstm(char_output.unsqueeze(1), hidden)
        logits = self.decoder(lstm_out.squeeze(1))
        return logits, hidden

# Function to convert sentence to character indices
def sentence_to_char_indices(sentence, char_dict):
    return [char_dict.get(c, char_dict['<unk>']) for c in sentence]

In [9]:
# Create a CharCNN instance
char_cnn = CharCNN(num_chars=num_chars, char_embed_size=char_embed_size, num_filters=num_filters, kernel_sizes=kernel_sizes)

# Function to convert sentence to character indices
def sentence_to_char_input(sentence, char_dict):
    char_seqs = []
    for word in sentence.split():
        char_seq = char_indices(word, char_dict)
        char_seqs.append(char_seq)
    char_input = pad_sequences(char_seqs, maxlen=max(len(word) for word in sentence.split()), padding='post')
    return char_input

# Prepare a sample sentence
test_sentence = "Hello World"
char_input = sentence_to_char_input(test_sentence, char_dict)

# Add an extra dimension for batch_size since CharCNN expects a batch of words
# char_input = char_input.unsqueeze(0) # would also work here if it's a single sequence
# char_input = char_input[None, :]

# Test the CharCNN
with torch.no_grad():
    char_cnn_output = char_cnn(char_input)

print("CharCNN output:", char_cnn_output)

CharCNN output: tensor([[6.8491e-01, 2.0344e-01, 1.0650e+00, 8.9655e-01, 7.2934e-01, 4.1404e-01,
         4.8933e-01, 6.6406e-01, 7.4998e-01, 5.8523e-01, 3.4747e-01, 4.8025e-01,
         2.0250e-01, 6.4700e-01, 6.4783e-01, 1.0188e+00, 4.8314e-01, 8.3459e-01,
         3.1567e-01, 3.6570e-02, 0.0000e+00, 4.3364e-01, 7.1529e-01, 4.5174e-01,
         8.1316e-01, 1.1642e-01, 2.2513e-01, 3.0454e-01, 2.0519e-01, 1.1365e-01,
         6.4268e-01, 8.2018e-01, 5.4768e-01, 4.2836e-02, 0.0000e+00, 8.9296e-01,
         2.4308e-01, 3.9687e-01, 8.6700e-02, 3.9741e-01, 6.9783e-01, 6.9034e-01,
         4.9028e-01, 0.0000e+00, 4.7883e-03, 5.3528e-01, 2.6567e-01, 0.0000e+00,
         1.2820e-01, 3.1328e-01, 4.4955e-01, 6.0443e-01, 4.3893e-01, 5.2464e-01,
         8.2506e-01, 7.0163e-01, 5.6531e-01, 8.6727e-01, 3.0726e-01, 3.7456e-01,
         0.0000e+00, 8.9690e-01, 8.7869e-01, 2.0284e-01, 0.0000e+00, 9.8108e-01,
         3.1052e-01, 7.3248e-01, 0.0000e+00, 3.7946e-01, 4.5382e-01, 2.4551e-01,
         3.6