subsampling on perceptual loss trick from fsrt paper #41

Open

johndpope opened this issue Jun 11, 2024 · 0 comments

johndpope (Owner) commented Jun 11, 2024

[Screenshot from 2024-06-11 12-26-16]

Drafted here:
https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling

I don't see a massive speed-up so far.

The crop size could also be randomized - sometimes full, half, quarter, etc. - so the loss occasionally sees the whole image. A rough sketch of that idea follows.
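One way to do that, as a rough sketch (the choose_sub_sample_size helper is hypothetical, not something in the draft branch):

import numpy as np

def choose_sub_sample_size(height, width):
    # Hypothetical helper: pick full / half / quarter resolution at random
    # so the perceptual loss occasionally sees the whole image.
    scale = np.random.choice([1.0, 0.5, 0.25])
    return max(1, int(height * scale)), max(1, int(width * scale))

The result could then be passed straight through as the sub_sample_size argument of forward in the draft class below.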

import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from facenet_pytorch import InceptionResnetV1
# MPGazeLoss is a project-local module; import it from this repo.

class PerceptualLoss(nn.Module):
    def __init__(self, device, weights={'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0}):
        super(PerceptualLoss, self).__init__()
        self.device = device
        self.weights = weights

        # VGG19 feature extractor (first 30 layers, frozen in eval mode)
        vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.vgg19 = nn.Sequential(*[vgg19[i] for i in range(30)]).to(device).eval()
        self.vgg19_layers = [1, 6, 11, 20, 29]

        # VGGFace network (facenet_pytorch InceptionResnetV1)
        self.vggface = InceptionResnetV1(pretrained='vggface2').to(device).eval()
        self.vggface_layers = [4, 5, 6, 7]

        # Gaze loss (project-local module)
        self.gaze_loss = MPGazeLoss(device)



    # Trick from FSRT Sec. 3.3 to reduce memory: compute the perceptual
    # losses on a random crop instead of the full image.
    # https://arxiv.org/pdf/2404.09736#page=5.58
    def forward(self, predicted, target, sub_sample_size=(128, 128), use_fm_loss=False):
        # Normalize input images
        predicted = self.normalize_input(predicted)
        target = self.normalize_input(target)

        # Crop the same random window from both images so the perceptual
        # comparison stays spatially aligned
        predicted, target = self.sub_sample_pair(predicted, target, sub_sample_size)

        # Compute VGG19 perceptual loss
        vgg19_loss = self.compute_vgg19_loss(predicted, target)

        # Compute VGGFace perceptual loss
        vggface_loss = self.compute_vggface_loss(predicted, target)

        # Compute gaze loss
        # gaze_loss = self.gaze_loss(predicted, target)

        # Compute total perceptual loss (the gaze term is disabled while
        # gaze_loss above is commented out)
        total_loss = (
            self.weights['vgg19'] * vgg19_loss +
            self.weights['vggface'] * vggface_loss
            # + self.weights['gaze'] * gaze_loss
        )

        if use_fm_loss:
            # Compute feature matching loss
            fm_loss = self.compute_feature_matching_loss(predicted, target)
            total_loss += fm_loss

        return total_loss

    def sub_sample_pair(self, predicted, target, sub_sample_size):
        assert predicted.ndim == 4, "Input tensor should have 4 dimensions (batch_size, channels, height, width)"
        assert predicted.shape[-2] >= sub_sample_size[0] and predicted.shape[-1] >= sub_sample_size[1], \
            "Sub-sample size should not exceed the tensor dimensions"

        height, width = predicted.shape[-2:]
        # Randomly sample the crop position so we cover the whole image over
        # training. The +1 makes the last valid offset reachable and avoids
        # randint(0, 0) when the crop size equals the image size.
        offset_h = np.random.randint(0, height - sub_sample_size[0] + 1)
        offset_w = np.random.randint(0, width - sub_sample_size[1] + 1)

        rows = slice(offset_h, offset_h + sub_sample_size[0])
        cols = slice(offset_w, offset_w + sub_sample_size[1])
        return predicted[..., rows, cols], target[..., rows, cols]

    def compute_vgg19_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target)

    def compute_vggface_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vggface, self.vggface_layers, predicted, target)

    def compute_feature_matching_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target, detach=True)

    def compute_perceptual_loss(self, model, layers, predicted, target, detach=False):
        loss = 0.0
        predicted_features = predicted
        target_features = target

        for i, layer in enumerate(model.children()):
            if isinstance(layer, nn.Linear):
                # Flatten spatial dimensions before a fully connected layer
                predicted_features = predicted_features.view(predicted_features.size(0), -1)
                target_features = target_features.view(target_features.size(0), -1)
            predicted_features = layer(predicted_features)
            target_features = layer(target_features)

            if i in layers:
                # L1 distance between feature maps at the selected layers
                if detach:
                    loss += torch.mean(torch.abs(predicted_features - target_features.detach()))
                else:
                    loss += torch.mean(torch.abs(predicted_features - target_features))

            if i >= layers[-1]:
                break  # deeper layers do not contribute to the loss

        return loss

    def normalize_input(self, x):
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        return (x - mean) / std
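
For reference, a minimal usage sketch (the 512x512 shapes and CUDA device are assumptions, and MPGazeLoss must be importable from this repo for the constructor to run):

loss_fn = PerceptualLoss(device='cuda')
predicted = torch.rand(1, 3, 512, 512, device='cuda', requires_grad=True)  # stand-in for generator output
target = torch.rand(1, 3, 512, 512, device='cuda')                         # stand-in for ground-truth frame
loss = loss_fn(predicted, target, sub_sample_size=(128, 128))
loss.backward()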