In [1]:

import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn

from environment import BusLine


class QNetwork(nn.Module):
    """Two-layer MLP mapping a state vector to one Q-value per action.

    Args:
        state_size: Number of features in the input state (size of the last input dimension).
        action_size: Number of discrete actions, i.e. the size of the output.
        hidden_size: Width of the single hidden layer. Defaults to 64, the original
            hard-coded width, so previously saved weights still load.
    """

    def __init__(self, state_size, action_size, hidden_size=64):
        super().__init__()
        # The attribute names fc1/fc2 become the state_dict keys; keep them stable
        # so checkpoints such as 'q_network.pth' remain loadable.
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, action_size)

    def forward(self, state):
        """Return raw (unnormalised) Q-values, shape ``(..., action_size)``."""
        hidden = torch.relu(self.fc1(state))
        return self.fc2(hidden)
    

# --- Run the trained Q-network on the bus line, one decision per simulated minute ---
state_size = 5
action_size = 2  # 0 for no departure, 1 for departure
agent = QNetwork(state_size, action_size)
# map_location lets weights saved on a GPU machine load on a CPU-only one.
agent.load_state_dict(torch.load('q_network.pth', map_location='cpu'))
agent.eval()  # inference only

# Initial observation fed to the agent.
# NOTE(review): these look like placeholder values — confirm whether BusLine exposes a real initial state.
state = [1, 2, 3, 4, 5]

passenger_path = "data/line3/passenger_dataframe_direction1.csv"
traffic_path = "data/line3/traffic-1.csv"
# NOTE(review): the two paths above are never passed to BusLine() — confirm whether they should be.

PLOTS_DIR = 'plots'
os.makedirs(PLOTS_DIR, exist_ok=True)  # savefig fails if the output folder does not exist

Line = BusLine()
current_minute = Line.current_minute
records = []  # one dict per simulated minute; converted to a DataFrame once after the loop
while current_minute < Line.last_minute:
    # Greedy action from the agent: argmax over the Q-values of the current state.
    with torch.no_grad():
        q_values = agent(torch.as_tensor(state, dtype=torch.float32))
    action = torch.argmax(q_values).item()

    # Environment update and reward.
    reward, new_state = Line.update_environment(action)
    records.append({"Time": current_minute, "Action": action, "Reward": reward})
    # Feed the new observation back in; without this the agent sees the same state every minute.
    state = new_state

    fig, _ = Line.plot()
    fig.savefig(os.path.join(PLOTS_DIR, f'{current_minute}.png'))
    plt.close(fig)  # otherwise every per-minute figure stays alive in memory
    current_minute += 1

history = pd.DataFrame(records, columns=["Time", "Action", "Reward"])
history

In [5]:
import os
import cv2

def _numbered_frames(frames_path, extension='.png'):
    """Return the filenames in ``frames_path`` named ``<int><extension>``, sorted by that integer."""
    extension = extension.lower()
    indexed = []
    for filename in os.listdir(frames_path):
        stem, ext = os.path.splitext(filename)
        if ext.lower() != extension:
            continue
        try:
            index = int(stem)
        except ValueError:
            continue  # not a numbered frame (e.g. '.ipynb_checkpoints', 'notes.png'); ignore it
        indexed.append((index, filename))
    indexed.sort()
    return [filename for _, filename in indexed]


def save_frames_as_video(frames_path, output_video_path, frame_rate=24, extension='.png'):
    """Stitch numbered image frames from a folder into an MP4 video.

    Frames are the files named ``<int><extension>`` (e.g. ``0.png``, ``1.png``, ``10.png``),
    ordered numerically. Any other entries in the folder are ignored.

    Args:
        frames_path: Folder containing the frames.
        output_video_path: Path of the video file to write (encoded with the 'mp4v' FourCC).
        frame_rate: Frames per second of the output video.
        extension: File extension of the frames to include (case-insensitive).

    Raises:
        ValueError: If no matching frames are found or a frame cannot be read.
        IOError: If the video writer cannot be opened.
    """
    filenames = _numbered_frames(frames_path, extension)
    if not filenames:
        raise ValueError(f"No {extension} frames found in {frames_path!r}")

    first = cv2.imread(os.path.join(frames_path, filenames[0]))
    if first is None:
        raise ValueError(f"Could not read frame {filenames[0]!r}")
    height, width = first.shape[:2]

    video = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
    try:
        if not video.isOpened():
            raise IOError(f"Could not open video writer for {output_video_path!r}")
        for filename in filenames:
            img = cv2.imread(os.path.join(frames_path, filename))
            if img is None:
                raise ValueError(f"Could not read frame {filename!r}")
            # VideoWriter silently drops frames whose size differs from the declared one.
            if img.shape[:2] != (height, width):
                img = cv2.resize(img, (width, height))
            video.write(img)
    finally:
        # Always finalise the container, even if a frame fails mid-way.
        video.release()

# Stitch the per-minute plot images written by the simulation loop into a video.
frames_path = 'plots'  # folder holding the numbered <minute>.png frames
output_video_path = 'output.mp4'  # where the resulting video is written
frame_rate = 5  # playback speed, in frames per second

save_frames_as_video(
    frames_path,
    output_video_path,
    frame_rate,
)
