# In [None]:
import torch
import torch.nn as nn

'''
Classifier heads used at the different stages of training for the classification part.
'''
# Network that passes each feature stream through its own network and adds the resulting outputs.

class ClassifierTwoStreamAfter(nn.Module):
    """Two-stream classifier that sums the outputs of two separate heads.

    RGB and flow features each go through their own
    ``Linear -> ReLU -> Dropout -> Linear`` head, and the two resulting score
    tensors are added element-wise.

    Args:
        n_inputs: Size of the last dimension of each input feature tensor.
        h_rgb: Hidden width of the RGB head.
        h_flow: Hidden width of the flow head.
        n_classes: Number of output scores produced per sample.
        dropout: Dropout probability used in both heads. Defaults to 0.8,
            the value that was previously hard-coded.
    """

    def __init__(self, n_inputs=2048, h_rgb=1024, h_flow=1024, n_classes=9, dropout=0.8):
        super().__init__()
        # Layer order (and therefore state_dict keys such as fc_rgb.0.weight,
        # fc_rgb.3.weight) is kept unchanged so existing checkpoints still load.
        self.fc_rgb = nn.Sequential(
            # NOTE(review): original comment suggests inputs are 1x2048 feature
            # vectors — confirm against the feature extractor.
            nn.Linear(n_inputs, h_rgb),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_rgb, n_classes),
        )
        self.fc_flow = nn.Sequential(
            nn.Linear(n_inputs, h_flow),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_flow, n_classes),
        )

    def forward(self, rgb, flow):
        """Score each stream with its own head and return the element-wise sum.

        Args:
            rgb: RGB features; last dimension must equal ``n_inputs``.
            flow: Flow features; last dimension must equal ``n_inputs``.

        Returns:
            Unnormalized class scores (no softmax is applied), with last
            dimension ``n_classes``.
        """
        rgb = self.fc_rgb(rgb)
        flow = self.fc_flow(flow)
        return rgb + flow

# Network that fuses the two feature streams by element-wise addition before a single shared network.

class ClassifierTwoStreamBefore(nn.Module):
    """Two-stream classifier that adds the features before a shared head.

    The RGB and flow feature tensors are summed element-wise, and the result
    goes through one ``Linear -> ReLU -> Dropout -> Linear`` head.

    Args:
        n_inputs: Size of the last dimension of the summed feature tensor.
        h_rgb: Hidden width of the shared head.
        h_flow: Unused. Accepted so the signature matches the other
            two-stream classifiers in this module.
        n_classes: Number of output scores produced per sample.
        dropout: Dropout probability of the head. Defaults to 0.8, the value
            that was previously hard-coded.
    """

    def __init__(self, n_inputs=2048, h_rgb=512, h_flow=512, n_classes=9, dropout=0.8):
        super().__init__()
        # Layer order (state_dict keys fc.0.*, fc.3.*) is unchanged so existing
        # checkpoints still load.
        self.fc = nn.Sequential(
            # NOTE(review): original comment suggests 1x2048 feature vectors — confirm.
            nn.Linear(n_inputs, h_rgb),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_rgb, n_classes),
        )

    def forward(self, rgb, flow):
        """Add the two feature tensors element-wise and score the sum.

        Args:
            rgb: RGB features; must be shape-compatible (broadcastable) with
                ``flow``, and the sum's last dimension must equal ``n_inputs``.
            flow: Flow features.

        Returns:
            Unnormalized class scores (no softmax is applied), with last
            dimension ``n_classes``.
        """
        fused = rgb + flow
        return self.fc(fused)

# Network that concatenates the RGB and flow features and passes the result through a single shared network.

class ClassifierTwoStreamConcat(nn.Module):
    """Two-stream classifier that concatenates the features before a shared head.

    The RGB and flow feature tensors are concatenated along dimension 1, and
    the result goes through one ``Linear -> ReLU -> Dropout -> Linear`` head.

    Args:
        n_inputs: Size of the last dimension after concatenation. For 2-D
            (batch, features) inputs this is rgb features + flow features.
            The default 4096 presumably means 2 x 2048; confirm against the
            feature extractor.
        h_rgb: Hidden width of the shared head.
        h_flow: Unused. Accepted so the signature matches the other
            two-stream classifiers in this module.
        n_classes: Number of output scores produced per sample.
        dropout: Dropout probability of the head. Defaults to 0.8, the value
            that was previously hard-coded.
    """

    def __init__(self, n_inputs=4096, h_rgb=1024, h_flow=512, n_classes=9, dropout=0.8):
        super().__init__()
        # Layer order (state_dict keys fc.0.*, fc.3.*) is unchanged so existing
        # checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(n_inputs, h_rgb),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h_rgb, n_classes),
        )

    def forward(self, rgb, flow):
        """Concatenate the two feature tensors along dim 1 and score the result.

        Args:
            rgb: RGB features. Its dim-1 size plus ``flow``'s must equal
                ``n_inputs`` when dim 1 is the last dimension (2-D inputs).
            flow: Flow features; all dimensions except dim 1 must match ``rgb``.

        Returns:
            Unnormalized class scores (no softmax is applied), with last
            dimension ``n_classes``.
        """
        joined = torch.cat((rgb, flow), 1)
        return self.fc(joined)

# Network that makes use only of the RGB features (spatial stream).

class ClassifierOneStream(nn.Module):
    """Classifier over a single feature stream (used for the RGB features).

    The head is ``Linear -> ReLU -> Dropout -> Linear``, stored as ``self.fc``.

    Args:
        n_inputs: Size of the last dimension of the input features.
        hidden_sz: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden activation.
        n_classes: Number of output scores produced per sample.
    """

    def __init__(self, n_inputs=2048, hidden_sz=512, dropout=0.5, n_classes=9):
        super().__init__()
        # Layers are created in forward order. This keeps both the parameter
        # initialisation sequence and the state_dict indices (fc.0, fc.3) as before.
        project = nn.Linear(n_inputs, hidden_sz)
        activate = nn.ReLU()
        regularise = nn.Dropout(dropout)
        classify = nn.Linear(hidden_sz, n_classes)
        self.fc = nn.Sequential(project, activate, regularise, classify)

    def forward(self, x):
        """Return unnormalized class scores for ``x`` (no softmax is applied)."""
        return self.fc(x)