Skip to content

Commit 171d0a2

Browse files
fix: correct syntax error in main.py by fixing newlines
1 parent f8c6eef commit 171d0a2

File tree

1 file changed

+67
-1
lines changed

1 file changed

+67
-1
lines changed

src/main.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,67 @@
1-
print('Hello, World!')
1+
import torch
2+
import torch.nn as nn
3+
from torch.utils.data import DataLoader
4+
from torch.optim import Adam
5+
6+
from models import LearnableGatedPooling
7+
from preprocess import prepare_data
8+
from train import train_model
9+
from evaluate import evaluate_model
10+
11+
def main():
12+
# Configuration
13+
input_dim = 768 # Example: BERT embedding dimension
14+
batch_size = 32
15+
seq_len = 10
16+
num_epochs = 10
17+
learning_rate = 0.001
18+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19+
20+
# Initialize model
21+
model = LearnableGatedPooling(input_dim=input_dim, seq_len=seq_len)
22+
23+
# Example data (replace with your actual data loading)
24+
dummy_sequences = [torch.randn(seq_len, input_dim) for _ in range(100)]
25+
26+
# Preprocess data
27+
processed_data, max_seq_len = prepare_data(dummy_sequences, batch_size)
28+
29+
# Create dummy targets (replace with your actual targets)
30+
dummy_targets = torch.randn(100, input_dim)
31+
32+
# Create data loaders
33+
dataset = torch.utils.data.TensorDataset(processed_data, dummy_targets)
34+
train_size = int(0.8 * len(dataset))
35+
test_size = len(dataset) - train_size
36+
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
37+
38+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
39+
test_loader = DataLoader(test_dataset, batch_size=batch_size)
40+
41+
# Initialize optimizer and loss function
42+
optimizer = Adam(model.parameters(), lr=learning_rate)
43+
criterion = nn.MSELoss()
44+
45+
# Train model
46+
training_history = train_model(
47+
model=model,
48+
train_loader=train_loader,
49+
optimizer=optimizer,
50+
criterion=criterion,
51+
num_epochs=num_epochs,
52+
device=device
53+
)
54+
55+
# Evaluate model
56+
evaluation_results = evaluate_model(
57+
model=model,
58+
test_loader=test_loader,
59+
criterion=criterion,
60+
device=device
61+
)
62+
63+
print("\nTraining completed!")
64+
print(f"Final test loss: {evaluation_results['test_loss']:.4f}")
65+
66+
if __name__ == "__main__":
67+
main()

0 commit comments

Comments
 (0)