
Commit 14dfbc3

feat: implement learnable gated pooling model
- Add LearnableGatedPooling model implementation
- Add training, evaluation, and preprocessing utilities
- Add configuration file
- Update requirements.txt with dependencies
1 parent f8c6eef commit 14dfbc3

File tree

7 files changed (+226, -2 lines)


config/model_config.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+"""
+Configuration for the Learnable Gated Pooling model
+"""
+
+class ModelConfig:
+    # Model parameters
+    INPUT_DIM = 768  # Dimension of input vectors
+    SEQ_LEN = 10  # Maximum sequence length
+
+    # Training parameters
+    BATCH_SIZE = 32
+    NUM_EPOCHS = 10
+    LEARNING_RATE = 0.001
+
+    # Data split
+    TRAIN_RATIO = 0.8
+
+    # Device configuration
+    USE_CUDA = True  # Will fall back to CPU if CUDA is not available
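
A minimal sketch of how these constants might be consumed (illustrative only; the import paths and the device-selection line are assumptions, not part of the commit):

    import torch

    from config.model_config import ModelConfig
    from src.models import LearnableGatedPooling  # assumes the repo root is on sys.path

    # Honor USE_CUDA, falling back to CPU when CUDA is unavailable
    device = torch.device('cuda' if ModelConfig.USE_CUDA and torch.cuda.is_available() else 'cpu')

    model = LearnableGatedPooling(ModelConfig.INPUT_DIM, ModelConfig.SEQ_LEN).to(device)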

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -1 +1,3 @@
-numpy
+torch>=2.0.0
+numpy>=1.21.0
+typing>=3.7.4

src/evaluate.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import torch
+from torch.utils.data import DataLoader
+from typing import Dict, Any
+from models import LearnableGatedPooling
+
+def evaluate_model(
+    model: LearnableGatedPooling,
+    test_loader: DataLoader,
+    criterion: torch.nn.Module,
+    device: torch.device
+) -> Dict[str, float]:
+    """
+    Evaluate the LearnableGatedPooling model.
+
+    Args:
+        model: Trained LearnableGatedPooling model
+        test_loader: DataLoader for test data
+        criterion: Loss function
+        device: Device to evaluate on (CPU/GPU)
+
+    Returns:
+        Dictionary containing evaluation metrics
+    """
+    model.eval()
+    total_loss = 0.0
+
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            loss = criterion(output, target)
+            total_loss += loss.item()
+
+    avg_loss = total_loss / len(test_loader)
+    return {'test_loss': avg_loss}

src/main.py

Lines changed: 67 additions & 1 deletion
@@ -1 +1,67 @@
-print('Hello, World!')
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.optim import Adam
+
+from models import LearnableGatedPooling
+from preprocess import prepare_data
+from train import train_model
+from evaluate import evaluate_model
+
+def main():
+    # Configuration
+    input_dim = 768  # Example: BERT embedding dimension
+    batch_size = 32
+    seq_len = 10
+    num_epochs = 10
+    learning_rate = 0.001
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # Initialize model
+    model = LearnableGatedPooling(input_dim=input_dim, seq_len=seq_len)
+
+    # Example data (replace with your actual data loading)
+    dummy_sequences = [torch.randn(seq_len, input_dim) for _ in range(100)]
+
+    # Preprocess data
+    processed_data, max_seq_len = prepare_data(dummy_sequences, batch_size)
+
+    # Create dummy targets (replace with your actual targets)
+    dummy_targets = torch.randn(100, input_dim)
+
+    # Create data loaders
+    dataset = torch.utils.data.TensorDataset(processed_data, dummy_targets)
+    train_size = int(0.8 * len(dataset))
+    test_size = len(dataset) - train_size
+    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
+
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size)
+
+    # Initialize optimizer and loss function
+    optimizer = Adam(model.parameters(), lr=learning_rate)
+    criterion = nn.MSELoss()
+
+    # Train model
+    training_history = train_model(
+        model=model,
+        train_loader=train_loader,
+        optimizer=optimizer,
+        criterion=criterion,
+        num_epochs=num_epochs,
+        device=device
+    )
+
+    # Evaluate model
+    evaluation_results = evaluate_model(
+        model=model,
+        test_loader=test_loader,
+        criterion=criterion,
+        device=device
+    )
+
+    print("\nTraining completed!")
+    print(f"Final test loss: {evaluation_results['test_loss']:.4f}")
+
+if __name__ == "__main__":
+    main()
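
Because main.py uses flat imports (from models import ..., from preprocess import ...), it is presumably meant to be run from inside src/ (e.g. python main.py); each epoch then prints its average training loss from train_model, followed by the final test loss.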

src/models.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class LearnableGatedPooling(nn.Module):
+    def __init__(self, input_dim, seq_len):
+        super(LearnableGatedPooling, self).__init__()
+        self.weights = nn.Parameter(torch.ones(input_dim))
+        self.gate_linear = nn.Linear(input_dim, 1)  # Linear layer for gating
+
+    def forward(self, x):
+        # x: (batch_size, seq_len, input_dim)
+        weighted_x = x * self.weights
+        gate_values = torch.sigmoid(self.gate_linear(x)).squeeze(2)  # (batch_size, seq_len)
+        gated_x = weighted_x * gate_values.unsqueeze(2)
+        pooled_vector = torch.mean(gated_x, dim=1)  # Average pooling
+        return pooled_vector
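
A quick shape check of the forward pass (illustrative usage, not part of the commit):

    import torch

    pool = LearnableGatedPooling(input_dim=768, seq_len=10)
    x = torch.randn(32, 10, 768)  # (batch_size, seq_len, input_dim)
    pooled = pool(x)              # per-dimension weights, per-token sigmoid gates, mean over seq_len
    print(pooled.shape)           # torch.Size([32, 768])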

src/preprocess.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+from typing import List, Tuple
+
+def prepare_data(sequences: List[torch.Tensor], batch_size: int) -> Tuple[torch.Tensor, int]:
+    """
+    Prepare input sequences for the LearnableGatedPooling model.
+
+    Args:
+        sequences: List of input sequences
+        batch_size: Size of each training batch
+
+    Returns:
+        Tuple containing:
+            - Padded and batched sequences
+            - Maximum sequence length
+    """
+    # Get maximum sequence length
+    max_seq_len = max(seq.size(0) for seq in sequences)
+
+    # Pad sequences to max_seq_len
+    padded_sequences = []
+    for seq in sequences:
+        if seq.size(0) < max_seq_len:
+            padding = torch.zeros(max_seq_len - seq.size(0), seq.size(1))
+            padded_seq = torch.cat([seq, padding], dim=0)
+            padded_sequences.append(padded_seq)
+        else:
+            padded_sequences.append(seq)
+
+    # Stack sequences into a single tensor
+    batched_sequences = torch.stack(padded_sequences)
+
+    return batched_sequences, max_seq_len
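
A small illustration of the zero-padding behavior (hypothetical tensors; note that batch_size is accepted but unused here, so batching itself is left to the DataLoader):

    import torch

    seqs = [torch.randn(3, 4), torch.randn(5, 4)]  # two sequences of different lengths
    batched, max_len = prepare_data(seqs, batch_size=2)
    print(batched.shape, max_len)                  # torch.Size([2, 5, 4]) 5
    # The 3-step sequence is zero-padded to 5 steps before stacking.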

src/train.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from typing import Dict, Any
+from models import LearnableGatedPooling
+
+def train_model(
+    model: LearnableGatedPooling,
+    train_loader: DataLoader,
+    optimizer: torch.optim.Optimizer,
+    criterion: nn.Module,
+    num_epochs: int,
+    device: torch.device
+) -> Dict[str, Any]:
+    """
+    Train the LearnableGatedPooling model.
+
+    Args:
+        model: LearnableGatedPooling model instance
+        train_loader: DataLoader for training data
+        optimizer: Optimizer for training
+        criterion: Loss function
+        num_epochs: Number of training epochs
+        device: Device to train on (CPU/GPU)
+
+    Returns:
+        Dictionary containing training history
+    """
+    model.to(device)
+    history = {'loss': []}
+
+    for epoch in range(num_epochs):
+        model.train()
+        epoch_loss = 0.0
+
+        for batch_idx, (data, target) in enumerate(train_loader):
+            data, target = data.to(device), target.to(device)
+
+            optimizer.zero_grad()
+            output = model(data)
+            loss = criterion(output, target)
+
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()
+
+        avg_epoch_loss = epoch_loss / len(train_loader)
+        history['loss'].append(avg_epoch_loss)
+        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}')
+
+    return history
