# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from fan import FANLayer

class FAN_Classifier(nn.Module):
    """Two-headed classifier built from stacked ``FANLayer`` modules.

    A shared trunk of ``FANLayer`` modules (each followed by dropout) feeds
    two independent ``FANLayer`` output heads: one for emotion classes and
    one for strength classes.

    Args:
        hidden_size: Width of every trunk layer and input width of both heads.
        num_layers: Controls trunk depth. The trunk holds one input layer
            plus ``num_layers - 2`` hidden-to-hidden layers (none when
            ``num_layers < 2``).
        dropout_rate: Dropout probability applied after every trunk layer.
        input_size: Number of input features per sample. Defaults to 173,
            the value that was previously hard-coded.
        num_emotions: Output width of the emotion head (default 6).
        num_strengths: Output width of the strength head (default 3).
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        dropout_rate,
        *,
        input_size=173,
        num_emotions=6,
        num_strengths=3,
    ):
        super().__init__()

        # Shared trunk. Attribute names (``layers``, ``emo_output_layer``,
        # ``strength_output_layer``, ``dropout``) and construction order are
        # kept as before so saved state_dicts still load and seeded
        # initialization is reproducible.
        self.layers = nn.ModuleList()
        self.layers.append(FANLayer(input_size, hidden_size))
        # range() of a non-positive count is empty, so num_layers < 2 just
        # means "no intermediate layers".
        for _ in range(num_layers - 2):
            self.layers.append(FANLayer(hidden_size, hidden_size))

        # Task-specific heads on top of the shared representation.
        # NOTE(review): the heads are FANLayer rather than nn.Linear, so their
        # outputs go through whatever transform FANLayer applies — confirm
        # this is intended for classification outputs.
        self.emo_output_layer = FANLayer(hidden_size, num_emotions)
        self.strength_output_layer = FANLayer(hidden_size, num_strengths)

        self.dropout = nn.Dropout(dropout_rate)

        self.init_weights()

    def init_weights(self):
        """Re-initialize every ``nn.Linear`` submodule in place.

        Weights get Xavier-uniform initialization and biases, when present,
        are zeroed. ``self.modules()`` is recursive, so this reaches any
        ``nn.Linear`` nested inside the ``FANLayer`` blocks. Parameters of
        any other module type are left untouched.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the shared trunk, then both output heads.

        Args:
            x: Input tensor. Presumably ``(batch, input_size)``; the exact
                shape contract is set by ``FANLayer`` (TODO confirm).

        Returns:
            Tuple ``(emo_output, strength_output)`` of raw head outputs. No
            softmax or other normalization is applied here.
        """
        for layer in self.layers:
            x = self.dropout(layer(x))
        return self.emo_output_layer(x), self.strength_output_layer(x)