
BYOL Single GPU implementation #1

Open · wants to merge 38 commits into base: master

Conversation

pranavsinghps1

Implementation of BYOL (https://arxiv.org/abs/2006.07733) on a single GPU.
Addresses issue #190.

@review-notebook-app
Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.

iseessel added a commit that referenced this pull request Jul 26, 2021

Summary:
Pull Request resolved: facebookresearch#343

Some basic changes to make this script work within FBinfra.

1. Register Manifold in PathManager.
2. In order to do (1), create fb/extra_scripts/convert_sharded_checkpoint.y and add the necessary dependencies in TARGETS.
3. Replace some torch.load calls to use PathManager (see the sketch after this commit summary).

Reviewed By: prigoyal

Differential Revision: D29166520

fbshipit-source-id: a61b4eb80d74526b0a7e2d38f973eb688b311a94
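
For context on item 3, the usual pattern is to route file access through PathManager rather than handing a raw path to torch.load. A minimal sketch, assuming fvcore's PathManager API of that period; the helper name and call site are illustrative, not the actual diff:

import torch
from fvcore.common.file_io import PathManager

# Hypothetical helper illustrating the pattern: PathManager resolves
# non-POSIX backends (e.g. Manifold) that a bare torch.load(path) cannot.
def load_checkpoint(checkpoint_path: str):
    with PathManager.open(checkpoint_path, "rb") as f:
        return torch.load(f, map_location="cpu")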
Outdated, resolved review threads on:
configs/config/dataset_catalog.json
configs/config/pretrain/byol/byol_1gpu.yaml (9 threads)
pranavsinghps1 (Author)

As per commit fc1217d, addressed the following:

  • Changed the file name to byol_8node_resnet.yaml.
  • Changed the key for imagenet1k_folder to the VISSL spec entry.
  • Parameterised img_pil_color_distortion.py to add brightness, contrast, saturation, hue, color_jitter_probability, and gray_probability (see the sketch after this list).
  • Changed LOG_FREQUENCY to 200 and CHECKPOINT_FREQUENCY to 10.
  • Changed momentum to 0.99.
  • Adjusted values for Gaussian blur and solarization.
  • Added an interpolation variable to RandomResizedCrop in the config.
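
A minimal sketch of such a parameterised transform, assuming the SimCLR/BYOL colour-distortion recipe; the parameter names follow the bullet above, but the exact VISSL signature may differ:

import torchvision.transforms as T

class ImgPilColorDistortion:
    """Illustrative colour-distortion transform with the knobs listed above.

    Applies random colour jitter with probability color_jitter_probability,
    then random grayscale with probability gray_probability.
    """

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1,
                 color_jitter_probability=0.8, gray_probability=0.2):
        color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)
        self.transform = T.Compose([
            T.RandomApply([color_jitter], p=color_jitter_probability),
            T.RandomGrayscale(p=gray_probability),
        ])

    def __call__(self, image):
        return self.transform(image)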

Outdated, resolved review threads on vissl/hooks/byol_hooks.py and vissl/losses/byol_loss.py.
def __init__(self, base_momentum: float, shuffle_batch: bool = True):
    super().__init__()
    self.base_momentum = base_momentum
    self.is_distributed = False
Owner: Question: Do we need this?


class BYOLHook(ClassyHook):
    """
    TODO: Update description
Owner: Update description to BYOL.


@staticmethod
def cosine_decay(training_iter, max_iters, initial_value):
    # TODO: Why do we need this min statement?
Owner: Comment this method. Put types in the method.

    return initial_value * cosine_decay_value
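
A sketch of the commented, type-hinted version the review asks for, shown as a free function for brevity; the min statement clamps the iteration count so the cosine argument stays in [0, pi]:

import math

def cosine_decay(training_iter: int, max_iters: int, initial_value: float) -> float:
    """Cosine-anneal initial_value down to 0 over max_iters iterations.

    Clamping training_iter keeps the cosine argument within [0, pi], so the
    schedule bottoms out at 0 instead of going negative if training runs
    past max_iters.
    """
    training_iter = min(training_iter, max_iters)
    cosine_decay_value = 0.5 * (1.0 + math.cos(math.pi * training_iter / max_iters))
    return initial_value * cosine_decay_value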

@staticmethod
def target_ema(training_iter, base_ema, max_iters):
Owner: Let's comment every method in byol_hooks.py and byol_loss.py and make sure they all have type hints.
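
For reference, the BYOL paper (https://arxiv.org/abs/2006.07733) anneals the target momentum as tau = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2, so the target network freezes as training ends. A commented, type-hinted sketch in that spirit, again as a free function:

import math

def target_ema(training_iter: int, base_ema: float, max_iters: int) -> float:
    """Anneal the EMA momentum from base_ema at step 0 towards 1.0 at max_iters.

    The gap (1 - base_ema) is cosine-decayed to 0, following the BYOL schedule.
    """
    training_iter = min(training_iter, max_iters)
    cosine = 0.5 * (1.0 + math.cos(math.pi * training_iter / max_iters))
    return 1.0 - (1.0 - base_ema) * cosine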


def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
    """
    Create the model replica called the target. This will slowly track
Owner: Improve comment. Something like: "Target network is exponential moving average of online network, ..."
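
In BYOL-style code, building the target usually amounts to a frozen copy of the online network; a minimal sketch under that assumption (the real hook operates on the VISSL task object):

import copy
import torch.nn as nn

def build_target_network(online_network: nn.Module) -> nn.Module:
    """Create the target as a frozen copy of the online network.

    The target starts identical and is afterwards moved only by EMA updates,
    never by gradient descent.
    """
    target_network = copy.deepcopy(online_network)
    for param in target_network.parameters():
        param.requires_grad = False
    return target_network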

@torch.no_grad()
def on_forward(self, task: tasks.ClassyTask) -> None:
    """
    - Update the target model.
Owner: Update comment for BYOL (this was copy/pasted from moco).
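
The update performed in on_forward is the standard EMA step; a hedged sketch assuming a one-to-one parameter pairing between the two networks (VISSL's hook walks the task's model instead):

import torch
import torch.nn as nn

@torch.no_grad()
def update_target_network(online: nn.Module, target: nn.Module, momentum: float) -> None:
    """EMA update: target <- momentum * target + (1 - momentum) * online."""
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        target_p.data.mul_(momentum).add_(online_p.data, alpha=1.0 - momentum)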

@register_loss("byol_loss")
class BYOLLoss(ClassyLoss):
    """
    TODO: change description
Owner: Change loss description.

and https://github.com/facebookresearch/moco for a reference implementation, reused here

Config params:
    embedding_dim (int): head output dimension
Owner: Do we need/use these vars?

"_BYOLLossConfig", ["embedding_dim", "momentum"]
)

def regression_loss(x, y):
Owner: Type-hints + comments for all these functions.
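
For reference, BYOL's regression loss is the squared L2 distance between L2-normalised vectors, which simplifies to 2 - 2 * cos_sim(x, y); a type-hinted, commented sketch:

import torch
import torch.nn.functional as F

def regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Per-sample BYOL regression loss between prediction x and target y.

    With x and y L2-normalised along the feature dimension,
    ||x - y||^2 = 2 - 2 * <x, y>.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return 2.0 - 2.0 * (x * y).sum(dim=-1)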

@classmethod
def from_config(cls, config: BYOLLossConfig):
    """
    Instantiates BYOLLoss from configuration.
Owner: Put the config options in the docstring here.


def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Given the encoder queries, the key and the queue of the previous queries,
Owner: I think this comment is copy/pasted.
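
A hedged sketch of what a BYOL-appropriate forward could document and compute, assuming both augmented views are stacked along the batch dimension and the target projections come from the hook; this mirrors common BYOL implementations, not necessarily VISSL's exact wiring:

import torch
import torch.nn.functional as F

def byol_symmetric_loss(online_prediction: torch.Tensor,
                        target_projection: torch.Tensor) -> torch.Tensor:
    """Symmetric BYOL loss: each view's prediction regresses the other
    view's target projection; the target side carries no gradient."""
    online_v1, online_v2 = torch.chunk(online_prediction, 2, dim=0)
    target_v1, target_v2 = torch.chunk(target_projection.detach(), 2, dim=0)
    loss = (2.0 - 2.0 * F.cosine_similarity(online_v1, target_v2, dim=-1)
            + 2.0 - 2.0 * F.cosine_similarity(online_v2, target_v1, dim=-1))
    return loss.mean()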

self.is_distributed = False

self.momentum = None
self.inv_momentum = None
Owner: Remove this.


self.momentum = None
self.inv_momentum = None
self.total_iters = None
Owner: Rename this to max_iters.
