In [264]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal_, kaiming_uniform_


In [265]:
# Seed the global RNG so the random sample drawn below is the same on every run.
torch.manual_seed(1244)

# One 100-element float32 sample, allocated on the 'mps' device and tracked by autograd.
# NOTE(review): the name `input` shadows the Python builtin; later cells refer to it by
# this name, so it is kept as-is here.
input = torch.randn(100, device="mps", dtype=torch.float32, requires_grad=True)

# Target with the same shape/dtype/device as `input`: zeros everywhere, 1 at index 0.
y = torch.zeros_like(input)
y[0] = 1

# Show both tensors as the cell's output.
input, y

(tensor([-2.5734, -1.2485,  0.5629,  0.7896, -0.4293,  0.9726, -1.4301,  0.6779,
          0.8384, -0.7334,  0.7964,  1.6117, -0.9399,  1.1813,  0.0902, -0.8812,
         -0.8176, -0.9856, -0.2382, -1.0007, -1.5119,  0.6298, -0.0136, -0.4105,
         -0.7050, -0.0282,  0.3113,  0.8288, -0.5914, -0.6119, -0.8549, -0.7866,
          0.8864,  0.8158,  0.5732,  1.3631, -0.2000,  0.9504, -2.1211,  0.5665,
          0.7490,  1.0691,  0.1702, -0.8173, -0.6776, -1.5998, -1.1331, -0.9786,
          0.4729, -0.3974, -0.9547, -1.0017, -0.7745,  1.0788, -1.1509,  1.1367,
         -0.0552,  0.1882, -0.0894,  0.8276, -0.3104,  0.6693, -0.9399, -0.7496,
         -1.1902, -0.7327,  1.8143, -0.6053,  0.0726, -0.1611, -0.9829, -0.6636,
         -0.5799,  0.4069, -2.1520, -1.2220,  0.1452, -0.8232,  0.2868,  0.2490,
         -0.7190,  0.2694, -2.6760, -1.1018,  1.1614, -1.0078, -0.4168,  0.0146,
         -0.2599, -0.3227,  0.6746, -0.1666,  0.6217, -1.0345,  0.5248, -1.3923,
          0.6642, -1.9983,  

In [266]:
class Module():
    """Hand-rolled single-layer network computing ``softmax(gelu(x @ w))``.

    Args:
        linear_in: Number of input features (rows of ``w``).
        linear_out: Number of output features (columns of ``w``).
        device: Device on which the weight is allocated. Defaults to ``'mps'``,
            the value that used to be hard-coded.

    Attributes:
        w: Weight tensor of shape ``(linear_in, linear_out)`` with
            ``requires_grad=True``.
    """

    def __init__(self, linear_in, linear_out, device='mps'):
        self.w = torch.empty(linear_in, linear_out, device=device)
        # `w` is stored as (in, out) because the forward pass computes `x @ w`.
        # kaiming_uniform_ takes "fan_in" from size(1) and "fan_out" from size(0),
        # so for this layout mode='fan_out' is what scales by the real number of
        # inputs (linear_in). For square weights both modes give the same bound.
        kaiming_uniform_(self.w, mode='fan_out')
        # Enable gradient tracking only after the in-place initialisation.
        self.w.requires_grad_(True)

    def __call__(self, x):
        """Apply the layer to ``x``.

        Args:
            x: Tensor whose last dimension has size ``linear_in``.

        Returns:
            Tensor with last dimension ``linear_out`` whose entries along that
            last dimension sum to 1.
        """
        a = F.gelu(x @ self.w)
        # Normalise over the feature axis. For a 1-D input this is the same as
        # dim=0; for batched input it avoids mixing different samples together.
        return torch.softmax(a, dim=-1)

class MyModel(nn.Module):
    """``nn.Module`` version of the same network: bias-free linear -> GELU -> softmax.

    Args:
        in_features: Size of the input feature dimension. Defaults to 100.
        out_features: Size of the output feature dimension. Defaults to 100.
        device: Device on which the layer's parameters are created. Defaults to
            ``'mps'``, the value that used to be hard-coded.
    """

    def __init__(self, in_features=100, out_features=100, device='mps'):
        super().__init__()
        self.linear1 = nn.Linear(
            in_features, out_features, bias=False, device=device, dtype=torch.float
        )

    def forward(self, x):
        """Return ``softmax(gelu(linear1(x)))`` taken over the last dimension.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with last dimension ``out_features`` whose entries along that
            last dimension sum to 1.
        """
        x = F.gelu(self.linear1(x))
        # dim=-1 equals dim=0 for a single 1-D sample, and keeps each row
        # normalised independently when a batch is passed.
        return torch.softmax(x, dim=-1)

# Build one instance of each implementation; the hand-rolled one maps 100 inputs to 100 outputs.
module = Module(linear_in=100, linear_out=100)
myModule = MyModel()


In [267]:
# Forward pass through the hand-rolled module; print its weights before the update.
out = module(input)
print(module.w)

# Sum of squared differences between the prediction and the target `y`.
Loss = (out - y).pow(2).sum()
print(Loss)

# Back-propagate to populate module.w.grad.
Loss.backward()

# One manual gradient-descent step with step size 1.0. The update runs under
# no_grad so that it is not recorded by autograd; the gradient is then cleared
# in place so it does not accumulate into the next backward pass.
with torch.no_grad():
    module.w -= module.w.grad * 1.0
    module.w.grad.zero_()

tensor([[ 0.0299,  0.1393,  0.2279,  ...,  0.2317,  0.0913, -0.0622],
        [ 0.0544,  0.2054, -0.2110,  ..., -0.2354,  0.2142, -0.0552],
        [-0.1136, -0.1403,  0.2341,  ...,  0.2169, -0.1314,  0.1425],
        ...,
        [-0.0288, -0.1261,  0.1990,  ..., -0.0904, -0.0309,  0.1893],
        [-0.1185, -0.2103, -0.1323,  ..., -0.2065,  0.1864, -0.1170],
        [-0.0663,  0.0811, -0.0772,  ...,  0.1881,  0.1048,  0.2442]],
       device='mps:0', requires_grad=True)
tensor(1.0430, device='mps:0', grad_fn=<SumBackward0>)


In [268]:
# Forward pass through the nn.Module implementation.
out = myModule(input)

# Sum of squared differences between the prediction and the target `y`.
Loss = (out - y).pow(2).sum()
print(Loss)

# Back-propagate to populate the gradients of the linear layer's parameters.
Loss.backward()

# One manual gradient-descent step (step size 1.0) on every parameter of linear1,
# done under no_grad so autograd does not record it. Each gradient is zeroed in
# place and the updated parameter is printed.
with torch.no_grad():
    for p in myModule.linear1.parameters():
        p -= p.grad * 1.0
        p.grad.zero_()
        print(p)

tensor(0.9955, device='mps:0', grad_fn=<SumBackward0>)
Parameter containing:
tensor([[-0.0890,  0.0133,  0.0824,  ...,  0.0540, -0.0632,  0.0349],
        [-0.0251,  0.0142, -0.0222,  ..., -0.0363, -0.0629,  0.0124],
        [ 0.0216,  0.0070, -0.0396,  ...,  0.0342, -0.0881, -0.0594],
        ...,
        [ 0.0995,  0.0429,  0.0062,  ...,  0.0797, -0.0206,  0.0346],
        [-0.0702,  0.0053,  0.0986,  ...,  0.0793,  0.0175,  0.0300],
        [-0.0371,  0.0388, -0.0105,  ..., -0.0840,  0.0536, -0.0370]],
       device='mps:0', requires_grad=True)


In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

In [None]:
# Memory profiling of a single forward pass of `myModule`.
#
# The model and input in this notebook are created on the 'mps' device, so
# asking the profiler for CUDA activity only (and sorting by a CUDA memory
# column) does not match the hardware in use. CPU activity is always recorded;
# CUDA activity is added only when a CUDA device is actually present, and the
# sort column is chosen to match.
# NOTE(review): these memory columns may not include allocations made by the
# MPS backend - confirm, and consider torch.mps.current_allocated_memory() if
# device-side numbers are needed.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, profile_memory=True, record_shapes=True) as prof:
    out = myModule(input)

sort_key = "self_cuda_memory_usage" if torch.cuda.is_available() else "self_cpu_memory_usage"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))