In [1]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchviz import make_dot

In [2]:
class HeatPINN(nn.Module):
    """Fully connected MLP used as a physics-informed network (PINN).

    Architecture::

        inpLayer -> SiLU -> (fcLayer -> SiLU) x n_hidden_passes -> outLayer

    No activation is applied after ``outLayer``.

    The same ``fcLayer`` module is applied on every hidden pass. All hidden
    passes therefore share one weight matrix and one bias vector.

    NOTE(review): the shared ``fcLayer`` may be unintentional. A typical deep
    MLP uses distinct weights per hidden layer (e.g. an ``nn.ModuleList``).
    The sharing is kept here so that parameter names, parameter count and
    state_dict keys stay unchanged. Confirm intent before changing it.

    Args:
        in_features: Size of the last input dimension. Defaults to 3
            (presumably space/time coordinates such as (x, y, t) -- TODO
            confirm).
        hidden_features: Width of every hidden representation. Defaults to 50.
        out_features: Size of the last output dimension. Defaults to 1.
        n_hidden_passes: Number of times the shared hidden layer is applied.
            Defaults to 8. Must be non-negative.

    Raises:
        ValueError: If ``n_hidden_passes`` is negative.
    """

    def __init__(self, in_features=3, hidden_features=50, out_features=1,
                 n_hidden_passes=8):
        super().__init__()
        if n_hidden_passes < 0:
            raise ValueError(
                f"n_hidden_passes must be >= 0, got {n_hidden_passes}"
            )
        self.n_hidden_passes = n_hidden_passes

        # Layers are created in the same order as before, so a given seed
        # yields the same initial weights as the original defaults.
        self.inpLayer = nn.Linear(in_features, hidden_features)
        self.fcLayer = nn.Linear(hidden_features, hidden_features)
        self.outLayer = nn.Linear(hidden_features, out_features)
        self.activationFunc = nn.SiLU()

    def forward(self, x):
        """Map ``(..., in_features)`` inputs to ``(..., out_features)`` outputs.

        ``nn.Linear`` acts on the last dimension. Both an unbatched
        ``(in_features,)`` vector and a batched ``(N, in_features)`` tensor
        are therefore accepted.
        """
        x = self.activationFunc(self.inpLayer(x))

        # Re-applying the single fcLayer means weight sharing across passes.
        for _ in range(self.n_hidden_passes):
            x = self.activationFunc(self.fcLayer(x))

        return self.outLayer(x)
    
# Seed torch's global RNG right before construction so that the randomly
# initialised weights (and hence y and anything computed from PINN) are
# reproducible after Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
PINN = HeatPINN()

In [3]:
# Single unbatched input point (presumably an (x, y, t) sample -- TODO confirm).
# The dtype is set at construction instead of casting afterwards.
X = torch.tensor([0.5, 0.5, 0.0], dtype=torch.float32)

# nn.Linear accepts a 1-D input. With HeatPINN's output layer of size 1,
# y has shape (1,).
y = PINN(X)

In [14]:
# Render the autograd graph of y.mean() to simple.png. Parameter nodes are
# labelled with their names from PINN.named_parameters().
named_params = dict(PINN.named_parameters())
simple_graph = make_dot(y.mean(), params=named_params)
simple_graph.render("simple", format="png")

'simple.png'

In [13]:
# Render a more detailed autograd graph of y.mean() to detailed.png. It uses
# named parameter labels plus the show_attrs / show_saved options of make_dot.
graph_kwargs = dict(
    params=dict(PINN.named_parameters()),
    show_attrs=True,
    show_saved=True,
)
detailed_graph = make_dot(y.mean(), **graph_kwargs)
detailed_graph.render("detailed", format="png")

'attached.png'

In [7]:
# Log PINN's graph (traced with the example input X) for TensorBoard.
# Using the writer as a context manager closes it even if add_graph raises,
# so the event file is never left open. SummaryWriter is imported in the
# top-of-notebook import cell.
with SummaryWriter("torchlogs/") as writer:
    writer.add_graph(PINN, X)

In [10]:
!tensorboard --logdir=torchlogs

^C
