In [2]:

import os
import random

import geopandas as gpd
import numpy as np
import pandas as pd
import sklearn
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# Reading the data file.
# The location can be overridden through the TATC_DATA_PATH environment variable so
# the notebook is runnable on machines other than the author's; when the variable is
# unset, the original hard-coded path is used unchanged.
file_path = os.environ.get(
    'TATC_DATA_PATH',
    '/Users/shalini/Desktop/TATC-RL/clustered_data_4months (2).geojson',
)
data = gpd.read_file(file_path)


In [16]:

# Sanity check: print the total number of records, then the number of records
# whose 'cnprcp_mean' value is strictly positive.
print(len(data))
print(len(data[data['cnprcp_mean'] > 0]))

94480
3630


In [4]:
# Timestamp preprocessing: parse the 'time' column into datetimes, then derive
# 'time_step' as the number of seconds elapsed since the earliest timestamp.
data['time'] = pd.to_datetime(data['time'])
earliest_time = data['time'].min()
data['time_step'] = data['time'].sub(earliest_time).dt.total_seconds()


In [5]:
# NOTE: everything between the triple quotes below is a plain string literal, not
# executed code. As the only expression in the cell, the notebook simply echoes its
# repr as the cell output. The disabled snippet reads a countries shapefile, takes the
# centroid of every geometry in `data`, intersects those centroids with the country
# polygons and labels each record 'land' or 'water' in an 'is_ground' column.
# NOTE(review): presumably kept as a reference for a future land/water feature; no
# 'is_ground' column is created by the code that actually runs here - confirm before
# relying on it.
'''
world = gpd.read_file('/Users/shalini/Desktop/TATC-RL/ne_110m_admin_0_countries.shp')
data['centroid'] = data.geometry.centroid
centroids = gpd.GeoDataFrame(data, geometry='centroid', crs=data.crs)
land_check = gpd.overlay(centroids, world, how='intersection')
data['is_ground'] = data['centroid'].apply(lambda x: 'land' if not land_check[land_check.geometry == x].empty else 'water')

data.head(), data.columns
'''

"\nworld = gpd.read_file('/Users/shalini/Desktop/TATC-RL/ne_110m_admin_0_countries.shp')\ndata['centroid'] = data.geometry.centroid\ncentroids = gpd.GeoDataFrame(data, geometry='centroid', crs=data.crs)\nland_check = gpd.overlay(centroids, world, how='intersection')\ndata['isground'] = data['centroid'].apply(lambda x: 'land' if not land_check[land_check.geometry == x].empty else 'water')\n\ndata.head(), data.columns\n"

In [6]:
#Data Processing

from sklearn.preprocessing import LabelEncoder, MinMaxScaler

#data['is_ground_enc'] = data['is_ground'].astype(int)

# Min-max scale the satellite latitude/longitude and the mean 'cnprcp' value, storing
# the results in new '*_norm' columns next to the raw ones. The fitted scaler is kept
# in `scaler`.
raw_columns = ['lat_sat', 'lon_sat', 'cnprcp_mean']
norm_columns = ['lat_norm', 'lon_norm', 'cnprcp_norm']

scaler = MinMaxScaler()
scaled_values = scaler.fit_transform(data[raw_columns])
data[norm_columns] = scaled_values

In [7]:
#Creating RL environment

class SatelliteEnv:
    """Sequential environment that walks through observation records in row order.

    Each step consumes the record at ``current_index``: the agent supplies a binary
    action, the reward is computed from that record's ``'cnprcp_norm'`` value, and
    the environment advances to the next record whose ``'time_step'`` is strictly
    greater than the one just consumed.

    Attributes:
        data: copy of the input frame with a fresh 0..n-1 index.
        current_index: row position of the record the next ``step`` will consume.
        current_time: ``'time_step'`` of the current record (``None`` before
            ``reset`` or when the data is empty).
        state_space: names of the columns that make up the state vector.
        n_actions: number of discrete actions (2: action 0 and action 1).
    """

    def __init__(self, data):
        """Store the data and define the state and action spaces.

        Args:
            data: pandas-style DataFrame containing every column listed in
                ``state_space`` plus ``'cnprcp_norm'``.

        Raises:
            KeyError: if any required column is missing from ``data``.
        """
        # reset_index returns a new frame, so positional lookups via .loc[i] are safe
        # and later changes to the caller's frame do not affect the environment.
        self.data = data.reset_index(drop=True)
        self.current_index = 0
        self.current_time = None

        # Columns forming the state vector (each entry needs its own comma: adjacent
        # string literals are silently concatenated into one name).
        self.state_space = [
            'lat_norm',
            'lon_norm',
            'time_step',
            'solar_hour',
        ]

        self.n_actions = 2

        # Fail early with a readable message instead of deep inside reset()/step().
        required = self.state_space + ['cnprcp_norm']
        missing = [col for col in required if col not in self.data.columns]
        if missing:
            raise KeyError(f"SatelliteEnv data is missing required columns: {missing}")

    def reset(self):
        """Rewind to the first record and return its state (``None`` if no data)."""
        self.current_index = 0
        if len(self.data) == 0:
            self.current_time = None
            return None
        self.current_time = self.data.loc[self.current_index, 'time_step']
        return self.get_state()

    def get_state(self):
        """Return the current state as a float array, or ``None`` past the last row."""
        if self.current_index < len(self.data):
            return self.data.loc[self.current_index, self.state_space].values.astype(float)
        return None

    def step(self, action):
        """Apply ``action`` to the current record and advance to the next time step.

        Args:
            action: 1 to take the observation, anything else to skip it.

        Returns:
            Tuple ``(next_state, reward, done)``. ``next_state`` is ``None`` and
            ``done`` is ``True`` once no record with a later ``'time_step'`` remains.

        Raises:
            RuntimeError: if called after the data has been exhausted.
        """
        n_rows = len(self.data)
        if self.current_index >= n_rows:
            raise RuntimeError("Episode has ended; call reset() before calling step() again.")

        # The reward signal is read straight from the data row: 'cnprcp_norm' is
        # deliberately not part of the state vector, so it cannot be looked up there.
        current_time = self.data.loc[self.current_index, 'time_step']
        cnprcp_norm = self.data.loc[self.current_index, 'cnprcp_norm']
        reward = self.calculate_reward(cnprcp_norm, action)

        # Skip every following record that does not move time strictly forward.
        self.current_index += 1
        while (
            self.current_index < n_rows
            and self.data.loc[self.current_index, 'time_step'] <= current_time
        ):
            self.current_index += 1

        if self.current_index >= n_rows:
            return None, reward, True

        self.current_time = self.data.loc[self.current_index, 'time_step']
        return self.get_state(), reward, False

    def calculate_reward(self, cnprcp_norm, action, threshold=0.0):
        """Basic reward function.

        With action 1: +1 if ``cnprcp_norm`` exceeds ``threshold``, otherwise -1.
        With any other action: -0.5 if ``cnprcp_norm`` exceeds ``threshold``
        (a missed record), otherwise 0.
        """
        above_threshold = cnprcp_norm > threshold
        if action == 1:
            return 1 if above_threshold else -1
        return -0.5 if above_threshold else 0


In [11]:
class DQN(nn.Module):
    """Fully connected Q-network: input_dim -> 128 -> 64 -> output_dim with ReLU."""

    def __init__(self, input_dim, output_dim):
        """Build the layers.

        Args:
            input_dim: size of the state vector fed to the network.
            output_dim: number of outputs (one Q-value per action).
        """
        # Must be the dunder constructor: a method merely named `init` is never
        # called by `DQN(...)`, so no layers would be registered on the module.
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension = input_dim)."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)


In [12]:
from collections import deque
import random

class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.capacity = capacity
        # deque(maxlen=...) drops the oldest transition automatically once full.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored (enables ``len(buffer)``)."""
        return len(self.memory)

    def push(self, transition):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)
    
    

In [10]:
class DQNAgent:
    """DQN agent with an online network, a target network and a replay buffer.

    Transitions stored in ``self.memory`` are expected to be tuples ordered as
    ``(state, action, next_state, reward, done)``; ``next_state`` may be ``None``
    for terminal transitions.
    """

    def __init__(self, state_dim, action_dim, lr=0.001):
        """Build both networks, the replay buffer, the optimiser and the loss.

        Args:
            state_dim: length of the state vector.
            action_dim: number of discrete actions.
            lr: learning rate for the Adam optimiser.
        """
        self.model = DQN(state_dim, action_dim)
        self.target_model = DQN(state_dim, action_dim)
        # Start with the target network as an exact copy of the online network.
        self.update_target_network()
        self.memory = ReplayMemory(10000)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
        self.action_dim = action_dim
        self.gamma = 0.99

    def update_target_network(self):
        """Hard update: copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def select_action(self, state, epsilon):
        """Epsilon-greedy action selection.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest predicted Q-value for ``state``.
        """
        if random.random() > epsilon:
            with torch.no_grad():
                state = torch.as_tensor(np.asarray(state, dtype=np.float32))
                q_values = self.model(state)
                action = torch.argmax(q_values).item()
        else:
            action = random.randrange(self.action_dim)
        return action

    def learn(self, batch_size):
        """Run one gradient step on a random minibatch from the replay buffer.

        Does nothing (returns ``None``) while fewer than ``batch_size`` transitions
        are stored.
        """
        # Read the underlying container so this works whether or not the buffer
        # class itself implements __len__.
        if len(self.memory.memory) < batch_size:
            return
        transitions = self.memory.sample(batch_size)
        states, actions, next_states, rewards, dones = zip(*transitions)

        # Terminal transitions carry next_state=None, which cannot be turned into a
        # tensor. Substitute zeros: the value is multiplied by (1 - done) == 0 below,
        # so the placeholder never influences the target.
        next_states = [
            np.zeros_like(state, dtype=np.float32) if next_state is None else next_state
            for state, next_state in zip(states, next_states)
        ]

        # Go through a single numpy array per field; building tensors directly from
        # tuples of arrays is slow.
        batch_state = torch.as_tensor(np.asarray(states, dtype=np.float32))
        batch_action = torch.as_tensor(np.asarray(actions, dtype=np.int64)).unsqueeze(1)
        batch_reward = torch.as_tensor(np.asarray(rewards, dtype=np.float32)).unsqueeze(1)
        batch_next_state = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
        batch_done = torch.as_tensor(np.asarray(dones, dtype=np.float32)).unsqueeze(1)

        # Q(s, a) of the actions actually taken.
        current_q_values = self.model(batch_state).gather(1, batch_action)
        with torch.no_grad():
            # Bootstrapped target: r + gamma * max_a' Q_target(s', a'), zeroed when done.
            max_next_q_values = self.target_model(batch_next_state).max(1)[0].unsqueeze(1)
            expected_q_values = batch_reward + self.gamma * max_next_q_values * (1 - batch_done)

        loss = self.criterion(current_q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def soft_update_target_network(self, tau=0.001):
        """Polyak update: target <- tau * online + (1 - tau) * target, per parameter."""
        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)