|
1 |
| -print('Hello, World!') |
| 1 | +print('Hello, World!')import torch |
| 2 | +import torch.nn as nn |
| 3 | +from torch.utils.data import DataLoader |
| 4 | +from torch.optim import Adam |
| 5 | + |
| 6 | +from models import LearnableGatedPooling |
| 7 | +from preprocess import prepare_data |
| 8 | +from train import train_model |
| 9 | +from evaluate import evaluate_model |
| 10 | + |
| 11 | +def main(): |
| 12 | + # Configuration |
| 13 | + input_dim = 768 # Example: BERT embedding dimension |
| 14 | + batch_size = 32 |
| 15 | + seq_len = 10 |
| 16 | + num_epochs = 10 |
| 17 | + learning_rate = 0.001 |
| 18 | + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 19 | + |
| 20 | + # Initialize model |
| 21 | + model = LearnableGatedPooling(input_dim=input_dim, seq_len=seq_len) |
| 22 | + |
| 23 | + # Example data (replace with your actual data loading) |
| 24 | + dummy_sequences = [torch.randn(seq_len, input_dim) for _ in range(100)] |
| 25 | + |
| 26 | + # Preprocess data |
| 27 | + processed_data, max_seq_len = prepare_data(dummy_sequences, batch_size) |
| 28 | + |
| 29 | + # Create dummy targets (replace with your actual targets) |
| 30 | + dummy_targets = torch.randn(100, input_dim) |
| 31 | + |
| 32 | + # Create data loaders |
| 33 | + dataset = torch.utils.data.TensorDataset(processed_data, dummy_targets) |
| 34 | + train_size = int(0.8 * len(dataset)) |
| 35 | + test_size = len(dataset) - train_size |
| 36 | + train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) |
| 37 | + |
| 38 | + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) |
| 39 | + test_loader = DataLoader(test_dataset, batch_size=batch_size) |
| 40 | + |
| 41 | + # Initialize optimizer and loss function |
| 42 | + optimizer = Adam(model.parameters(), lr=learning_rate) |
| 43 | + criterion = nn.MSELoss() |
| 44 | + |
| 45 | + # Train model |
| 46 | + training_history = train_model( |
| 47 | + model=model, |
| 48 | + train_loader=train_loader, |
| 49 | + optimizer=optimizer, |
| 50 | + criterion=criterion, |
| 51 | + num_epochs=num_epochs, |
| 52 | + device=device |
| 53 | + ) |
| 54 | + |
| 55 | + # Evaluate model |
| 56 | + evaluation_results = evaluate_model( |
| 57 | + model=model, |
| 58 | + test_loader=test_loader, |
| 59 | + criterion=criterion, |
| 60 | + device=device |
| 61 | + ) |
| 62 | + |
| 63 | + print("\nTraining completed!") |
| 64 | + print(f"Final test loss: {evaluation_results['test_loss']:.4f}") |
| 65 | + |
| 66 | +if __name__ == "__main__": |
| 67 | + main() |
0 commit comments