In [5]:
import torch
import pygame
import numpy as np
import cv2


# --- Simulation parameters -------------------------------------------------
width = 256    # grid size along the first spatial axis
height = 256   # grid size along the second spatial axis
D = 0.1        # diffusion coefficient (scales the density update)
nu = 0.01      # kinematic viscosity (scales the velocity update)
dt = 0.01      # time step
scale = 2      # upscaling factor applied to the grid for the recorded video

# --- Video recording -------------------------------------------------------
fps = 60  # frame rate the output video is encoded at
video_filename = 'simulation_video.avi'
frame_size = (width*scale, height*scale)
fourcc = cv2.VideoWriter_fourcc(*'XVID')
video_writer = cv2.VideoWriter(video_filename, fourcc, fps, frame_size)

# --- Initial fields --------------------------------------------------------
# Both fields are sampled on the CPU and then moved to the GPU.
# rho: single-channel density, uniform in [0, 1).
rho = torch.rand((1, width, height), dtype=torch.float32)
rho = rho.cuda()
# u: two channels, one per 2D velocity component, standard normal.
u = torch.randn((2, width, height), dtype=torch.float32)
u = u.cuda()
print(rho.shape, u.shape)
class Update(torch.nn.Module):
    """One explicit time step for a density field and a 2D velocity field.

    The step is a placeholder for a real discretization: both fields are
    advanced by the output of a randomly initialised 3x3 convolution applied
    to the stacked (rho, u) state.

    Args:
        D: factor applied (together with ``dt``) to the density increment.
        nu: factor applied (together with ``dt``) to the velocity increment.
        dt: time step.
        alpha: divisor for the neighbour weights of the fixed kernel stored in
            ``self.conv`` (each neighbour gets ``1 / alpha``).
    """

    def __init__(self, D, nu, dt, alpha=9):
        super().__init__()
        self.D, self.nu, self.dt, self.alpha = D, nu, dt, alpha

        # Fixed, non-trainable kernel: every neighbour weighted 1/alpha,
        # centre weighted 0. It is not used by forward().
        self.conv = torch.nn.Conv2d(1, 1, 3, padding=1, bias=False)
        neighbour_weight = 1 / self.alpha
        kernel = torch.full((1, 1, 3, 3), neighbour_weight, dtype=torch.float32)
        kernel[0, 0, 1, 1] = 0
        self.conv.weight.data = kernel
        self.conv.weight.requires_grad = False

        # Randomly initialised kernels mapping the 3-channel stacked state to
        # the density increment (1 channel) and velocity increment (2 channels).
        self.rho_conv = torch.nn.Conv2d(3, 1, 3, padding=1, bias=False)
        self.u_conv = torch.nn.Conv2d(3, 2, 3, padding=1, bias=False)

    def forward(self, rho, u):
        """Advance both fields by one step.

        Args:
            rho: density tensor; concatenated with ``u`` along dim 0, so the
                two together must supply the 3 input channels of the convs
                (the notebook passes an unbatched ``(1, W, H)`` tensor).
            u: velocity tensor (the notebook passes ``(2, W, H)``).

        Returns:
            Tuple ``(rho, u)`` holding the updated fields. Both increments are
            computed from the state as it was before this step.
        """
        state = torch.cat((rho, u), dim=0)
        rho_step = self.rho_conv(state) * self.dt * self.D
        u_step = self.u_conv(state) * self.dt * self.nu
        return rho + rho_step, u + u_step



torch.Size([1, 256, 256]) torch.Size([2, 256, 256])


In [6]:
# Pygame setup
pygame.init()
# Size the window with `scale` (not a hard-coded 2) so that captured frames
# always match the frame size the video writer was opened with; OpenCV
# silently drops frames whose size differs.
window_size = (width * scale, height * scale)
screen = pygame.display.set_mode(window_size)
update = Update(D, nu, dt).cuda()
running = True
try:
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_t:  # 'T' key pressed
                    # Re-roll both update kernels and restart from fresh random fields.
                    update.rho_conv = torch.nn.Conv2d(3, 1, 3, padding=1, bias=False).cuda()
                    update.u_conv = torch.nn.Conv2d(3, 2, 3, padding=1, bias=False).cuda()
                    rho = torch.rand((1, width, height), dtype=torch.float32).cuda()
                    u = torch.randn((2, width, height), dtype=torch.float32).cuda()  # 2 for 2D velocity components
                    print("Convs updated")

        # Advance the simulation. The conv weights are trainable parameters, so
        # without no_grad every step would be appended to one ever-growing
        # autograd graph and memory use would climb for as long as the loop runs.
        with torch.no_grad():
            rho, u = update(rho, u)

        # Visualization: clip the density to [0, 1] and convert to 8-bit gray.
        gray = np.clip(rho[0].detach().cpu().numpy(), 0, 1)
        gray = np.uint8(gray * 255)
        # pygame expects a (width, height, 3) array. The colour axis has to be
        # added as the LAST axis; repeating along axis 0 and reshaping would
        # pack three neighbouring density values into one pixel's RGB channels.
        vis_rho = np.repeat(gray[:, :, np.newaxis], 3, axis=2)

        screen_surface = pygame.surfarray.make_surface(vis_rho)
        screen_surface = pygame.transform.scale(screen_surface, window_size)
        screen.blit(screen_surface, (0, 0))
        pygame.display.flip()

        # Grab what was just drawn and append it to the video.
        frame = pygame.surfarray.array3d(screen)
        frame = frame.transpose([1, 0, 2])  # (width, height, 3) -> (rows, cols, 3) for OpenCV
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))  # Convert to BGR for OpenCV
finally:
    # Always finalize the video file and close the window, even if the loop
    # is interrupted by an exception.
    video_writer.release()
    pygame.quit()


Convs updated
