In [None]:
import torch
from torch import nn
from torchsummary import summary
from torchvision import models

class Inception_LSTM(nn.Module):
    """Clip classifier: per-frame Inception-v3 features followed by an LSTM.

    Every frame of a clip is encoded independently by the convolutional trunk
    of an ImageNet-pretrained torchvision Inception-v3 and globally
    average-pooled to a vector. The per-frame vectors are run through an LSTM,
    and the LSTM output at the last time step is classified by a linear layer
    (with dropout in front of it).

    Args:
        num_classes: Number of output classes.
        latent_dim: Size of the per-frame feature vector fed to the LSTM. It
            must equal the channel count of the Inception-v3 trunk output
            (the input width of its ImageNet ``fc`` layer).
        lstm_layers: Number of stacked LSTM layers.
        hidden_dim: LSTM hidden size.
        bidirectional: If True, use a bidirectional LSTM; the classifier input
            then has ``2 * hidden_dim`` features.

    Raises:
        ValueError: If ``latent_dim`` does not match the trunk's output size.
    """

    # Children of torchvision's Inception3 that are not part of the sequential
    # convolutional trunk. AuxLogits is a side branch that turns the Mixed_6e
    # activations into class logits; leaving it inside an nn.Sequential would
    # feed a 2-D logits tensor into Mixed_7a. avgpool/dropout/fc are the
    # ImageNet classification head, replaced by this model's own head.
    _NON_TRUNK_CHILDREN = ("AuxLogits", "avgpool", "dropout", "fc")

    def __init__(self, num_classes, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        super(Inception_LSTM, self).__init__()
        inception = models.inception_v3(pretrained=True)

        # The trunk's output channel count is the input width of the ImageNet fc.
        trunk_channels = inception.fc.in_features
        if latent_dim != trunk_channels:
            raise ValueError(
                f"latent_dim must equal the Inception-v3 feature size "
                f"({trunk_channels}), got {latent_dim}"
            )

        # named_children() yields modules in registration order, presumably the
        # same order Inception3.forward applies them -- verify for the
        # installed torchvision version.
        # NOTE(review): Inception3.forward can re-normalize its input
        # (transform_input); calling the trunk directly skips that step, so
        # frames must already be normalized as the trunk expects -- confirm
        # against the data pipeline.
        self.features = nn.Sequential(
            *(
                module
                for name, module in inception.named_children()
                if name not in self._NON_TRUNK_CHILDREN
            )
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # batch_first=True because forward() feeds (batch, seq, feature) tensors
        # and reads the last time step from dimension 1.
        self.lstm = nn.LSTM(
            latent_dim,
            hidden_dim,
            lstm_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        # A bidirectional LSTM concatenates both directions' hidden states.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, num_classes)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: 5-D tensor of shape (batch, seq_length, channels, height, width).

        Returns:
            Logits of shape (batch, num_classes).
        """
        batch_size, seq_length, c, h, w = x.shape
        # Fold time into the batch so the 2-D CNN sees one frame per sample;
        # reshape (unlike view) also accepts non-contiguous input.
        frames = x.reshape(batch_size * seq_length, c, h, w)
        features = self.avgpool(self.features(frames))
        # (batch * seq, C, 1, 1) -> (batch, seq, C)
        features = features.reshape(batch_size, seq_length, -1)
        lstm_out, _ = self.lstm(features)
        # Output at the last time step: (batch, hidden_dim * num_directions).
        last_step = lstm_out[:, -1, :]
        return self.fc(self.dropout(last_step))

# Initialize Inception_LSTM model
model_inception_lstm = Inception_LSTM(num_classes=2)

# torchsummary builds its random input on `device` (default "cuda" whenever
# CUDA is available) but never moves the model, so put both on one device.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
model_inception_lstm = model_inception_lstm.to(summary_device)

# Print model summary.
# torchsummary cannot take the full model: its forward hook calls .size() on
# every element of a tuple output, and nn.LSTM returns (output, (h_n, c_n)),
# so the nested tuple raises AttributeError mid-forward (before torchsummary
# removes its hooks). Summarize the per-frame CNN trunk with torchsummary and
# print the sequence head separately.
clip_length = 20  # frames per clip, matching the (20, 3, 299, 299) clip input

# eval() so the dummy forward does not update the pretrained BatchNorm
# running statistics with random data; restore the previous mode afterwards.
was_training = model_inception_lstm.training
model_inception_lstm.eval()
# input_size excludes the batch dimension: one RGB 299x299 frame.
summary(model_inception_lstm.features, input_size=(3, 299, 299), device=summary_device)
model_inception_lstm.train(was_training)

print(f"Sequence head (input: (batch, {clip_length}, features)):")
print(model_inception_lstm.lstm)
print(model_inception_lstm.dropout)
print(model_inception_lstm.fc)
total_params = sum(p.numel() for p in model_inception_lstm.parameters())
trainable_params = sum(
    p.numel() for p in model_inception_lstm.parameters() if p.requires_grad
)
print(f"Total params (full model): {total_params:,}")
print(f"Trainable params (full model): {trainable_params:,}")
