In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import BertModel,BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class ClimateDataEncoder(nn.Module):
    """Encode gridded climate fields and text into vectors of equal width.

    The climate branch is a small CNN (conv -> pool -> ReLU -> two conv+ReLU
    blocks -> pool); its output is averaged over channels and projected by
    ``fc1``. The text branch runs ``bert-base-uncased`` and projects the
    pooled output with ``fc2``. Both branches return ``output_dim`` features
    per sample.

    Args:
        input_channels: Number of channels in the climate input.
        out_channels: Number of channels produced by every convolution.
        output_dim: Size of the feature vector returned by both branches.
        input_size: ``(height, width)`` of the climate grids. ``fc1`` is sized
            from it; the default ``(121, 161)`` yields the 1200 input
            features that used to be hard-coded.

    Raises:
        ValueError: If ``input_size`` is too small to survive both pools.
    """

    def __init__(self, input_channels, out_channels, output_dim, input_size=(121, 161)):
        super().__init__()

        self.conv1 = nn.Conv2d(input_channels, out_channels, kernel_size=3, stride=1, padding="same")
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Build a *fresh* conv+ReLU block per entry. Repeating one module
        # instance inside the list would make both entries share the same
        # weights, i.e. one layer applied twice rather than two layers.
        self.conv2 = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding="same"),
                    nn.ReLU(),
                )
                for _ in range(2)
            ]
        )
        # Alias kept so existing `model.convpool` accesses and the
        # `convpool.*` state_dict keys continue to resolve.
        self.convpool = self.conv2[0]
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # After channel-averaging, fc1 sees one value per pooled grid cell.
        self.fc1 = nn.Linear(self._pooled_feature_count(input_size), output_dim)

        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Size the text projection from the loaded model instead of a literal.
        self.fc2 = nn.Linear(self.bert_model.config.hidden_size, output_dim)

    @staticmethod
    def _pooled_feature_count(input_size):
        """Return the number of grid cells left after the two max-pools.

        The convolutions use ``padding="same"`` with stride 1 and therefore
        keep the spatial size; each 2x2, stride-2 pool floor-halves it.
        """
        height, width = input_size
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for two 2x2 max-pool layers"
            )
        return pooled_height * pooled_width

    def encode_climate_data(self, x):
        """Map a ``(batch, channels, height, width)`` tensor to ``(batch, output_dim)``.

        The spatial size of ``x`` must match the ``input_size`` given at
        construction, otherwise the final linear layer rejects the input.
        """
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = F.relu(x)
        for conv_block in self.conv2:
            x = conv_block(x)
        x = self.maxpool2(x)
        # Collapse the spatial grid: (batch, channels, H' * W').
        x = x.flatten(start_dim=2)
        # Average over the channel axis, leaving one value per grid cell.
        x = x.mean(dim=1)
        return self.fc1(x)

    def encode_text(self, text):
        """Tokenize ``text``, run BERT and project the pooled output.

        ``text`` is passed straight to the tokenizer (the demo uses a list of
        strings). Returns a tensor with ``output_dim`` features per entry.
        """
        inputs = self.bert_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        # The tokenizer always builds CPU tensors; move them to wherever this
        # module's weights live so the encoder also works after `.to(device)`.
        inputs = inputs.to(self.fc2.weight.device)
        outputs = self.bert_model(**inputs)
        return self.fc2(outputs.pooler_output)

    def forward(self, climate_data, text):
        """Return ``(climate_features, text_features)`` for one batch."""
        climate_features = self.encode_climate_data(climate_data)
        text_features = self.encode_text(text)

        return climate_features, text_features

        
    


In [7]:
# Smoke test: push one synthetic batch through both encoder branches.
batch_size = 10
input_climate = torch.randn(batch_size, 3, 121, 161)  # random (batch, channels, height, width) grid
input_text = ["Climate data example text" for _ in range(batch_size)]  # one caption per sample
model = ClimateDataEncoder(
    input_channels=3,
    out_channels=32,
    output_dim=512,
)
output = model(climate_data=input_climate, text=input_text)  # (climate_features, text_features)

In [9]:
# Shapes of the climate and text feature tensors, in that order.
tuple(features.shape for features in output)

(torch.Size([10, 512]), torch.Size([10, 512]))