# In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
class Num_Image_Regression(nn.Module):
    """GRU-based two-branch regressor over a flat numeric + image feature row.

    Each input row is split at column 20: the first 20 columns form the
    numeric branch and the remaining columns (2048, per the layer sizes
    declared here) form the image branch.  Each branch runs through a
    2-layer GRU (hidden size 128) followed by a 128 -> 64 linear layer with
    ReLU; the two 64-wide results are concatenated and reduced 128 -> 64 -> 18.

    Notes:
        * ``fc1_num``, ``fc1_img`` and ``img_conv`` are constructed but not
          used by ``forward``.  They are kept so attribute names and
          ``state_dict`` keys stay unchanged for existing callers.  The GRUs
          take the place of the first linear layer of each branch, which is
          the only wiring under which all declared layer sizes line up.
          NOTE(review): confirm this matches the intended architecture.
        * Building ``img_conv`` loads pretrained MobileNetV3 weights, which
          may trigger a download on first use.
    """

    # Column at which a row switches from numeric features to image features.
    NUM_NUMERIC_FEATURES = 20

    def __init__(self):
        super().__init__()
        # Legacy first-layer projection; unused in forward(), kept for
        # attribute / state_dict compatibility.
        self.fc1_num = nn.Linear(20, 128)  # 20 numerical features
        self.fc2_num = nn.Linear(128, 64)
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1
        # is the checkpoint the legacy flag resolved to.  Not used in forward().
        self.img_conv = mobilenet_v3_large(
            weights=MobileNet_V3_Large_Weights.IMAGENET1K_V1
        )
        self.gru_num = nn.GRU(20, 128, 2, batch_first=True)
        self.gru_img = nn.GRU(2048, 128, 2, batch_first=True)

        # Legacy first-layer projection; unused in forward(), kept for
        # attribute / state_dict compatibility.
        self.fc1_img = nn.Linear(2048, 128)  # ResNet output size
        self.fc2_img = nn.Linear(128, 64)

        self.fc1_combined = nn.Linear(128, 64)
        self.fc2_combined = nn.Linear(64, 18)  # Output size (18 for regression)

    def forward(self, x):
        """Compute the 18-value regression output for a batch of rows.

        Args:
            x: 2-D tensor ``(batch, 20 + 2048)``; numeric features first,
                image features after.

        Returns:
            Tensor of shape ``(batch, 18)``.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected a 2-D (batch, features) tensor, got {x.dim()}-D input"
            )

        split = self.NUM_NUMERIC_FEATURES
        x_num = x[:, :split]  # First 20 features are numerical
        x_img = x[:, split:]  # The rest are image features

        # nn.GRU reads a 2-D tensor as ONE unbatched sequence, which would turn
        # the batch rows into time steps and leak state between samples.  Add
        # an explicit length-1 sequence axis so every row is processed alone.
        x_num, _ = self.gru_num(x_num.unsqueeze(1))
        x_img, _ = self.gru_img(x_img.unsqueeze(1))

        # GRU output is (batch, 1, 128); drop the sequence axis and feed the
        # 128 -> 64 layers (the 20/2048-input layers cannot accept GRU output).
        x_num = torch.relu(self.fc2_num(x_num.squeeze(1)))
        x_img = torch.relu(self.fc2_img(x_img.squeeze(1)))

        x_combined = torch.cat((x_num, x_img), dim=1)
        x_combined = torch.relu(self.fc1_combined(x_combined))
        x_combined = self.fc2_combined(x_combined)

        return x_combined

class MultimodalNet(nn.Module):
    """Feed-forward regressor that fuses numeric and image feature vectors.

    A flat input row is split at column 20.  The leading 20 columns go
    through a 20 -> 128 -> 64 MLP and the remaining columns (2048 expected)
    through a 2048 -> 128 -> 64 MLP, each layer followed by ReLU.  The two
    64-wide embeddings are concatenated and mapped 128 -> 64 -> 18, with no
    activation on the final layer.
    """

    def __init__(self):
        super().__init__()

        # Numeric branch (20 numerical features in).
        self.fc1_num = nn.Linear(in_features=20, out_features=128)
        self.fc2_num = nn.Linear(in_features=128, out_features=64)

        # Image branch (2048 features in; original note: ResNet output size).
        self.fc1_img = nn.Linear(in_features=2048, out_features=128)
        self.fc2_img = nn.Linear(in_features=128, out_features=64)

        # Fusion head: 64 + 64 concatenated features -> 18 regression outputs.
        self.fc1_combined = nn.Linear(in_features=128, out_features=64)
        self.fc2_combined = nn.Linear(in_features=64, out_features=18)

    def forward(self, x):
        """Return the ``(batch, 18)`` prediction for ``(batch, 20 + 2048)`` input."""
        numeric_part, image_part = x[:, :20], x[:, 20:]

        numeric_hidden = self.fc1_num(numeric_part).relu()
        numeric_hidden = self.fc2_num(numeric_hidden).relu()

        image_hidden = self.fc1_img(image_part).relu()
        image_hidden = self.fc2_img(image_hidden).relu()

        fused = torch.cat([numeric_hidden, image_hidden], dim=1)
        return self.fc2_combined(self.fc1_combined(fused).relu())
