<a href="https://colab.research.google.com/github/hadwin-357/ProteinMPNN_breakdown/blob/main/testing_model_utils_functions_loss_smoothed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# testing_model_util_functions
# loss_smooth function
'''
inputs-  S ground truth, log_probs: calculated logprob, mask: position mask(if masked 1, else 0),
weights: smoothing labeling weights (regularization parameteres)

Function explain
Turn grond truth into a onehot encode
add noise to onehot encode, then normalize (regularization)
calculate loss  (similar to cross entropy)
caclaculate average loss
'''

In [3]:
import torch

In [4]:
def loss_smoothed(S, log_probs, mask, weight=0.1):
    """ Negative log probabilities """
    S_onehot = torch.nn.functional.one_hot(S, 21).float()

    # Label smoothing
    S_onehot = S_onehot + weight / float(S_onehot.size(-1))
    S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)

    loss = -(S_onehot * log_probs).sum(-1)
    loss_av = torch.sum(loss * mask) / 2000.0 #fixed
    return loss, loss_av

In [5]:
# break down
# S is the ground truth


# Example ground truth labels
# Suppose you have a batch size of 4 and sequence length of 10
batch_size = 4
sequence_length = 10
num_classes = 21  # Assuming 21 classes

# Create random ground truth labels
S = torch.randint(0, num_classes, (batch_size, sequence_length))

# Verify the shape and data type of S
print("Shape of S:", S.shape)
print("Data type of S:", S.dtype)


Shape of S: torch.Size([4, 10])
Data type of S: torch.int64


In [9]:
S_onehot = torch.nn.functional.one_hot(S, 21).float()
S_onehot.shape

torch.Size([4, 10, 21])

In [10]:
# Label smoothing/Regularization
weight=0.1
S_onehot = S_onehot + weight / float(S_onehot.size(-1))
S_onehot[0,0,:]

tensor([0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048,
        0.0048, 0.0048, 0.0048, 0.0048, 1.0048, 0.0048, 0.0048, 0.0048, 0.0048,
        0.0048, 0.0048, 0.0048])

In [11]:
#normalized so that all probilities add up to 1
S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)
S_onehot[0,0,:]


tensor([0.0043, 0.0043, 0.0043, 0.0043, 0.0043, 0.0043, 0.0043, 0.0043, 0.0043,
        0.0043, 0.0043, 0.0043, 0.0043, 0.9134, 0.0043, 0.0043, 0.0043, 0.0043,
        0.0043, 0.0043, 0.0043])

In [13]:
# Example predicted probabilities (logits)
# Suppose you have a batch size of 4, sequence length of 10, and 21 classes
batch_size = 4
sequence_length = 10
num_classes = 21

# Create random logits (predicted probabilities)
logits = torch.randn(batch_size, sequence_length, num_classes)

# Verify the shape and data type of logits
print("Shape of logits:", logits.shape)
print("Data type of logits:", logits.dtype)


Shape of logits: torch.Size([4, 10, 21])
Data type of logits: torch.float32


In [14]:
logits[0,0,:]

tensor([-0.1022, -0.4064,  0.6391, -0.7980, -0.0428, -0.5432, -0.8819, -0.2547,
         0.1194, -1.0239, -0.0670, -0.4794,  0.7475,  0.7643, -2.1898, -0.1602,
         0.8382, -0.5503, -0.2953, -0.9511,  0.4750])

In [15]:
log_probs=logits
loss = -(S_onehot * log_probs).sum(-1) #
loss.shape

torch.Size([4, 10])

In [16]:
# Example mask
# Suppose you have a batch size of 4 and sequence length of 10
batch_size = 4
sequence_length = 10

# Create a random mask where only the first 8 positions are valid
mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
mask[:, :8] = True

# Verify the shape and data type of mask
print("Shape of mask:", mask.shape)
print("Data type of mask:", mask.dtype)


Shape of mask: torch.Size([4, 10])
Data type of mask: torch.bool


In [17]:
loss_av = torch.sum(loss * mask) / 2000.0

In [18]:
loss_av

tensor(-0.0020)