In [1]:
!pip install torch torchvision torchaudio



In [3]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        embed_dim: size of the last (feature) dimension of the input.
        max_len: longest sequence length the precomputed table covers.

    The input is assumed to be ``(batch, seq_len, embed_dim)``: the table is
    sliced along dim 1 and broadcast over dim 0.
    """

    def __init__(self, embed_dim, max_len=100):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_dim))

        encoding = torch.zeros(max_len, embed_dim)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # An odd embed_dim has one fewer odd column than even column, so the
        # cosine half uses only the first embed_dim // 2 frequencies.
        encoding[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])

        # Register as a buffer so .to(device) / .half() move it together with
        # the module. persistent=False keeps it out of state_dict (as before):
        # it is deterministic and rebuilt here. unsqueeze(0) adds a leading
        # batch axis for broadcasting.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}"
            )
        # .to() is a no-op when the buffer already lives on x's device.
        return x + self.encoding[:, :seq_len, :].to(x.device)

class TransformerEncoder(nn.Module):
    """Token-level Transformer encoder.

    Pipeline: embedding -> sinusoidal positional encoding -> LayerNorm ->
    stacked ``nn.TransformerEncoderLayer`` -> linear projection to vocab logits.

    Args:
        embed_dim: model width (``d_model``).
        num_heads: attention heads per layer.
        ff_dim: hidden size of each layer's feed-forward block.
        num_layers: number of stacked encoder layers.
        vocab_size: number of token ids; also the size of the output logits.
        max_len: longest sequence supported by the positional encoding.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, num_layers, vocab_size, max_len):
        super(TransformerEncoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        self.positional_encoding = PositionalEncoding(embed_dim, max_len)

        # batch_first=True: the embedding and positional encoding produce
        # (batch, seq, embed_dim). With the default batch_first=False the layers
        # would treat dim 0 as the sequence axis and attend across batch samples.
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=0.1,
            activation='relu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        self.fc_out = nn.Linear(embed_dim, vocab_size)

        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, src_key_padding_mask=None):
        """Map token ids to per-position vocabulary logits.

        Args:
            x: integer token ids of shape ``(batch, seq_len)``.
            src_key_padding_mask: optional bool tensor ``(batch, seq_len)``;
                ``True`` marks padding positions that must not be attended to
                (e.g. ``x == 0`` when 0 is the pad id). ``None`` attends to all.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        x = self.embedding(x)

        x = self.positional_encoding(x)

        x = self.layer_norm(x)

        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)

        return self.fc_out(x)

In [11]:
import torch
import torch.optim as optim
import torch.nn as nn

# Hyperparameters
embed_dim = 512     # model width (d_model)
num_heads = 8       # attention heads per layer; must divide embed_dim
ff_dim = 2048       # hidden size of each encoder layer's feed-forward block
num_layers = 6      # number of stacked encoder layers
vocab_size = 10000  # number of token ids (embedding rows and output logits)
max_len = 100       # longest sequence the positional-encoding table covers
learning_rate = 1e-4
batch_size = 64
num_epochs = 5
seed = 42           # fixes weight init and later dropout masks for reproducible runs

torch.manual_seed(seed)
model = TransformerEncoder(embed_dim, num_heads, ff_dim, num_layers, vocab_size, max_len)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# Token id 0 is treated as padding: positions whose label is 0 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)




In [12]:
# Show the module hierarchy (submodule names, layer sizes, dropout rates) as this cell's output.
model

TransformerEncoder(
  (embedding): Embedding(10000, 512)
  (positional_encoding): PositionalEncoding()
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_out): Linear(in_features=512, out_features=10000, bias=True)
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)

In [13]:
def train(model, train_dataloader, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module mapping ``inputs`` to logits whose last dim is the class
            count (e.g. vocabulary size).
        train_dataloader: any iterable of ``(inputs, labels)`` pairs; ``len()``
            is not required, so iterable-style loaders and generators work.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss taking flattened logits ``(N, C)`` and targets ``(N,)``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yields no batches.
        NOTE: every batch is weighted equally, regardless of its token count.
    """
    model.train()
    # Move batches to wherever the model lives so the loop works on CPU or GPU.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    total_loss = 0.0
    num_batches = 0
    for inputs, labels in train_dataloader:
        if device is not None:
            inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        # reshape (not view): labels are often a non-contiguous slice such as tokens[:, 1:].
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


In [None]:
def validate(model, val_dataloader, criterion):
    """Evaluate ``model`` on a validation set and return the mean per-batch loss.

    Switches the model to eval mode (dropout off) and leaves it there. Runs
    under ``torch.no_grad()``, so no gradients are recorded and no parameters
    change.

    Args:
        model: module mapping ``inputs`` to logits whose last dim is the class count.
        val_dataloader: any iterable of ``(inputs, labels)`` pairs; ``len()`` is
            not required.
        criterion: loss taking flattened logits ``(N, C)`` and targets ``(N,)``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yields no batches.
        NOTE: every batch is weighted equally, regardless of its token count.
    """
    model.eval()
    # Move batches to wherever the model lives so evaluation works on CPU or GPU.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in val_dataloader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # reshape (not view): labels are often a non-contiguous slice such as tokens[:, 1:].
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")
