In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the LSTM-based classifier model
class LSTMClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        h0 = torch.zeros(self.lstm.num_layers, x.size(0), self.lstm.hidden_size).to(x.device)
        c0 = torch.zeros(self.lstm.num_layers, x.size(0), self.lstm.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # Take the output from the last time step
        out = self.fc(out)
        return out

# Load hidden states
hidden_states = torch.load('hidden_states.ViT.pt')

# Convert hidden states to torch tensor
hidden_states = torch.tensor(hidden_states, dtype=torch.float32)

# Example target labels for training (replace with actual labels)
# Assuming there are 10 classes and hidden states shape is (num_samples, seq_length, feature_dim)
num_samples, seq_length, feature_dim = hidden_states.shape
num_classes = 10  # Adjust based on your classification task
targets = torch.randint(0, num_classes, (num_samples,))

# Define the LSTM-based classifier
hidden_dim = 256
num_layers = 2
classifier = LSTMClassifier(input_dim=feature_dim, hidden_dim=hidden_dim, num_layers=num_layers, num_classes=num_classes)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

# Training the classifier
epochs = 10
classifier.train()
for epoch in range(epochs):
    epoch_loss = 0
    for i in tqdm(range(num_samples), desc=f"Epoch {epoch + 1}/{epochs}"):
        optimizer.zero_grad()
        
        # Select the hidden states for the current sample
        input_data = hidden_states[i].unsqueeze(0)  # Add batch dimension
        target = targets[i].unsqueeze(0)
        
        # Forward pass
        output = classifier(input_data)
        
        # Compute loss
        loss = criterion(output, target)
        epoch_loss += loss.item()
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / num_samples}")

# Save the trained classifier model
torch.save(classifier.state_dict(), 'lstm_classifier_model.pth')
logging.info("LSTM classifier model saved to lstm_classifier_model.pth")


  hidden_states = torch.tensor(hidden_states, dtype=torch.float32)
Epoch 1/10: 100%|██████████| 1/1 [00:00<00:00,  7.88it/s]
2024-06-28 17:01:08,541 - INFO - Epoch 1/10, Loss: 2.3007333278656006
Epoch 2/10: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]
2024-06-28 17:01:08,675 - INFO - Epoch 2/10, Loss: 2.1291894912719727
Epoch 3/10: 100%|██████████| 1/1 [00:00<00:00,  7.87it/s]
2024-06-28 17:01:08,803 - INFO - Epoch 3/10, Loss: 1.8556734323501587
Epoch 4/10: 100%|██████████| 1/1 [00:00<00:00,  8.01it/s]
2024-06-28 17:01:08,929 - INFO - Epoch 4/10, Loss: 1.3017315864562988
Epoch 5/10: 100%|██████████| 1/1 [00:00<00:00,  7.78it/s]
2024-06-28 17:01:09,059 - INFO - Epoch 5/10, Loss: 0.4733865261077881
Epoch 6/10: 100%|██████████| 1/1 [00:00<00:00,  7.69it/s]
2024-06-28 17:01:09,190 - INFO - Epoch 6/10, Loss: 0.08597078174352646
Epoch 7/10: 100%|██████████| 1/1 [00:00<00:00,  7.85it/s]
2024-06-28 17:01:09,318 - INFO - Epoch 7/10, Loss: 0.022380398586392403
Epoch 8/10: 100%|██████████| 1/1 [

In [2]:
classifier.state_dict()

OrderedDict([('lstm.weight_ih_l0',
              tensor([[ 0.0451,  0.0646, -0.0487,  ...,  0.0495,  0.0454,  0.0267],
                      [-0.0270, -0.0241,  0.0064,  ...,  0.0238, -0.0332, -0.0189],
                      [-0.0216,  0.0276,  0.0637,  ..., -0.0586,  0.0441, -0.0132],
                      ...,
                      [-0.0433,  0.0146,  0.0510,  ..., -0.0431, -0.0415,  0.0565],
                      [ 0.0479,  0.0504,  0.0657,  ...,  0.0137, -0.0403, -0.0152],
                      [ 0.0342, -0.0361, -0.0467,  ...,  0.0110, -0.0389, -0.0284]])),
             ('lstm.weight_hh_l0',
              tensor([[ 0.0446, -0.0529,  0.0470,  ...,  0.0509, -0.0126,  0.0320],
                      [ 0.0141,  0.0589, -0.0005,  ...,  0.0606,  0.0364, -0.0175],
                      [ 0.0671,  0.0305,  0.0130,  ..., -0.0040,  0.0197, -0.0067],
                      ...,
                      [ 0.0237, -0.0113, -0.0241,  ...,  0.0621, -0.0139, -0.0629],
                      [-0.0541,  

In [3]:
import torch
import torch.nn as nn

# Load precomputed hidden states
hidden_states = torch.load('hidden_states.Vit.pt')  # Shape: (batch_size, seq_length, hidden_dim * 2)

# Define a classifier model
class LSTMClassifier(nn.Module):
    def __init__(self, hidden_dim=512, num_classes=10):
        super(LSTMClassifier, self).__init__()
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # Multiply by 2 for bidirectional LSTM
    
    def forward(self, x):
        return self.fc(x)

# Train classifier
num_classes = 5  # Example number of classes, can be tuned
classifier = LSTMClassifier(hidden_dim=512, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

# Example target labels for training
target = torch.randint(0, num_classes, (hidden_states.size(0), hidden_states.size(1)))

# Training loop
epochs = 10
for epoch in range(epochs):
    optimizer.zero_grad()
    output = classifier(hidden_states)  # Shape: (batch_size, seq_length, num_classes)
    loss = criterion(output.view(-1, num_classes), target.view(-1))
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')


Epoch 1/10, Loss: 1.6061102151870728
Epoch 2/10, Loss: 1.5997072458267212
Epoch 3/10, Loss: 1.5961003303527832
Epoch 4/10, Loss: 1.5945404767990112
Epoch 5/10, Loss: 1.5939780473709106
Epoch 6/10, Loss: 1.5937540531158447
Epoch 7/10, Loss: 1.5935587882995605
Epoch 8/10, Loss: 1.5932397842407227
Epoch 9/10, Loss: 1.5927554368972778
Epoch 10/10, Loss: 1.5921286344528198


In [4]:
classifier.state_dict()

OrderedDict([('fc.weight',
              tensor([[ 0.0330, -0.0013,  0.0235,  ...,  0.0345, -0.0162,  0.0071],
                      [-0.0045, -0.0024,  0.0265,  ..., -0.0199, -0.0076, -0.0215],
                      [ 0.0223, -0.0006, -0.0237,  ..., -0.0222,  0.0321, -0.0128],
                      [-0.0111,  0.0236,  0.0054,  ..., -0.0127,  0.0330,  0.0300],
                      [-0.0327,  0.0120, -0.0040,  ..., -0.0047, -0.0002,  0.0152]])),
             ('fc.bias',
              tensor([ 0.0233,  0.0280, -0.0015,  0.0042,  0.0262]))])

In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the LSTM-based classifier model
class LSTMClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        h0 = torch.zeros(self.lstm.num_layers, x.size(0), self.lstm.hidden_size).to(x.device)
        c0 = torch.zeros(self.lstm.num_layers, x.size(0), self.lstm.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # Take the output from the last time step
        out = self.fc(out)
        return out

# Load hidden states
hidden_states = torch.load('hidden_states.ViT.pt')

# Convert hidden states to torch tensor
hidden_states = torch.tensor(hidden_states, dtype=torch.float32)

# Example target labels for training (replace with actual labels)
# Assuming there are 10 classes and hidden states shape is (num_samples, seq_length, feature_dim)
num_samples, seq_length, feature_dim = hidden_states.shape
num_classes = 10  # Adjust based on your classification task
targets = torch.randint(0, num_classes, (num_samples,))

# Define the LSTM-based classifier
hidden_dim = 256
num_layers = 2
classifier = LSTMClassifier(input_dim=feature_dim, hidden_dim=hidden_dim, num_layers=num_layers, num_classes=num_classes)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

# Training the classifier
epochs = 10
classifier.train()
for epoch in range(epochs):
    epoch_loss = 0
    for i in tqdm(range(num_samples), desc=f"Epoch {epoch + 1}/{epochs}"):
        optimizer.zero_grad()
        
        # Select the hidden states for the current sample
        input_data = hidden_states[i].unsqueeze(0)  # Add batch dimension
        target = targets[i].unsqueeze(0)
        
        # Forward pass
        output = classifier(input_data)
        
        # Compute loss
        loss = criterion(output, target)
        epoch_loss += loss.item()
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / num_samples}")

# Save the trained classifier model
torch.save(classifier.state_dict(), 'lstm_classifier_model.pth')
logging.info("LSTM classifier model saved to lstm_classifier_model.pth")

# Evaluating the model (Example)
classifier.eval()
with torch.no_grad():
    sample_data = hidden_states[0].unsqueeze(0)  # Select the first sample
    logits = classifier(sample_data)
    probabilities = torch.softmax(logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)

    print("Logits:\n", logits)
    print("Probabilities:\n", probabilities)
    print("Predicted Class:\n", predicted_class)


  hidden_states = torch.tensor(hidden_states, dtype=torch.float32)
Epoch 1/10: 100%|██████████| 1/1 [00:00<00:00,  6.31it/s]
2024-06-28 17:01:29,652 - INFO - Epoch 1/10, Loss: 2.280545473098755
Epoch 2/10: 100%|██████████| 1/1 [00:00<00:00,  8.46it/s]
2024-06-28 17:01:29,771 - INFO - Epoch 2/10, Loss: 2.1285898685455322
Epoch 3/10: 100%|██████████| 1/1 [00:00<00:00,  7.38it/s]
2024-06-28 17:01:29,908 - INFO - Epoch 3/10, Loss: 1.8932427167892456
Epoch 4/10: 100%|██████████| 1/1 [00:00<00:00,  8.36it/s]
2024-06-28 17:01:30,029 - INFO - Epoch 4/10, Loss: 1.4138613939285278
Epoch 5/10: 100%|██████████| 1/1 [00:00<00:00,  8.48it/s]
2024-06-28 17:01:30,148 - INFO - Epoch 5/10, Loss: 0.6126784086227417
Epoch 6/10: 100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
2024-06-28 17:01:30,266 - INFO - Epoch 6/10, Loss: 0.1227807104587555
Epoch 7/10: 100%|██████████| 1/1 [00:00<00:00,  8.63it/s]
2024-06-28 17:01:30,383 - INFO - Epoch 7/10, Loss: 0.029265454038977623
Epoch 8/10: 100%|██████████| 1/1 [00

Logits:
 tensor([[-1.9154, -0.3192, -1.6114,  7.6093, -1.7138, -1.6429, -1.8956, -1.6316,
         -2.1856, -1.3644]])
Probabilities:
 tensor([[7.2948e-05, 3.5994e-04, 9.8869e-05, 9.9893e-01, 8.9239e-05, 9.5799e-05,
         7.4403e-05, 9.6889e-05, 5.5675e-05, 1.2656e-04]])
Predicted Class:
 tensor([3])


In [10]:
hidden_states

NpzFile 'hidden_states.ViT.pt' with keys: hidden_states.ViT/data.pkl, hidden_states.ViT/byteorder, hidden_states.ViT/data/0, hidden_states.ViT/version, hidden_states.ViT/.data/serialization_id