In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
class Model(nn.Module):
    def __init__(self, num_classes=3):
        super(Model, self).__init__()
        
      
        self.conv1 = nn.Conv1d(1, 32, 7, padding=3)
        self.in1 = nn.InstanceNorm2d(32)  
        
        self.conv2 = nn.Conv1d(32, 64, 3, padding=1)
        self.in2 = nn.InstanceNorm2d(64)
        
        self.conv3 = nn.Conv1d(64, 128, 3, padding=1)
        self.in3 = nn.InstanceNorm2d(128)
        
        self.pool = nn.AvgPool1d(3)

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_classes)
        
    def forward(self, x):
        x = F.relu(self.in1(self.conv1(x)))
        x = self.pool(x)
        
        x = F.relu(self.in2(self.conv2(x)))
        x = self.pool(x)
        
        x = F.relu(self.in3(self.conv3(x)))
        x = self.pool(x)
        
        x = self.global_pool(x)
        x = torch.flatten(x, start_dim =1)
        return self.fc(x)

In [15]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

mps_device = torch.device("mps")

tensor([1.], device='mps:0')


In [17]:
#from simple_model import Model

# Load model and weights
model = Model()
model.to(mps_device)
model.load_state_dict(torch.load('80devf1.pt', map_location=mps_device))

print(next(model.parameters()).device)

# List of layer names to freeze
layers_to_freeze = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']

# Go through layers freeze if in frezze list
for name, param in model.named_parameters():
    if name in layers_to_freeze:
        param.requires_grad = False

for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")

model.eval()

mps:0
conv1.weight: False
conv1.bias: False
conv2.weight: False
conv2.bias: False
conv3.weight: True
conv3.bias: True
fc.weight: True
fc.bias: True


  model.load_state_dict(torch.load('80devf1.pt', map_location=mps_device))


Model(
  (conv1): Conv1d(1, 32, kernel_size=(7,), stride=(1,), padding=(3,))
  (in1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (conv2): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (in2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (conv3): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (in3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (pool): AvgPool1d(kernel_size=(3,), stride=(3,), padding=(0,))
  (global_pool): AdaptiveAvgPool1d(output_size=1)
  (fc): Linear(in_features=128, out_features=3, bias=True)
)