Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
64b7108
Added barlow twins example
dewball345 Nov 7, 2021
81a4647
Delete barlow_twins.py
dewball345 Nov 7, 2021
761ef39
Add files via upload
dewball345 Nov 7, 2021
e806385
Update barlow_twins.py
dewball345 Nov 8, 2021
34b0e3e
Update barlow_twins.py
dewball345 Nov 9, 2021
f065598
Added updates
dewball345 Nov 18, 2021
aa41101
Capitalization updates
dewball345 Nov 18, 2021
2478bc0
added minor changes, subclassing api
dewball345 Nov 24, 2021
2051a05
Update barlow_twins.py
dewball345 Nov 24, 2021
32dce72
fixed remaining issues
dewball345 Nov 24, 2021
00a2605
parenthesis
dewball345 Nov 29, 2021
ab65cb6
two periods
dewball345 Dec 1, 2021
153e6c1
formatted
dewball345 Dec 1, 2021
d5d73d7
Syntax error removed
dewball345 Dec 3, 2021
eb3dcb5
added md
dewball345 Dec 3, 2021
340cb67
Add files via upload
dewball345 Dec 3, 2021
7830f52
Add files via upload
dewball345 Dec 3, 2021
6f3d9b3
fixes
dewball345 Dec 20, 2021
02f9ace
Update barlow_twins.py
dewball345 Dec 21, 2021
72bac37
Update barlow_twins.ipynb
dewball345 Dec 21, 2021
14d71de
Delete barlow_twins_18_0.png
dewball345 Dec 21, 2021
5c1961b
Delete barlow_twins_18_1.png
dewball345 Dec 21, 2021
3f4a16d
Delete barlow_twins_32_1.png
dewball345 Dec 21, 2021
bb98696
Add files via upload
dewball345 Dec 21, 2021
cd6959c
Update barlow_twins.py
dewball345 Dec 21, 2021
b9890a2
Update barlow_twins.md
dewball345 Dec 21, 2021
742e5ee
Update barlow_twins.ipynb
dewball345 Dec 21, 2021
80a7335
new updates
dewball345 Jan 29, 2022
67a12ad
make it follow black conventions
dewball345 Jan 29, 2022
9e459ec
Update barlow_twins.py
dewball345 Feb 6, 2022
50931e5
Update barlow_twins.py
fchollet Feb 9, 2022
2bbdbf6
Create vicreg.py
dewball345 Apr 14, 2022
286f824
not for vicreg
dewball345 May 30, 2022
54c75dc
not for vicreg
dewball345 May 30, 2022
1036b1e
not for vicreg
dewball345 May 30, 2022
c833881
Use master branch of tensorflow similarity
dewball345 May 30, 2022
f527268
Add files via upload
dewball345 Jun 17, 2022
5464203
Add files via upload
dewball345 Jun 17, 2022
a8ae355
Create t
dewball345 Jun 17, 2022
b25452d
Add files via upload
dewball345 Jun 17, 2022
506f023
Delete t
dewball345 Jun 17, 2022
bf6583e
Update vicreg.md
dewball345 Jun 21, 2022
5147163
Update vicreg.md
dewball345 Jun 21, 2022
533b679
fixed deletion
dewball345 Aug 9, 2023
5b0c43a
seems to work
dewball345 Aug 9, 2023
979102f
Merge branch 'master' into vicreg-branch
dewball345 Aug 9, 2023
b68d18a
black
dewball345 Aug 9, 2023
a8b5546
Merge branch 'vicreg-branch' of https://github.com/dewball345/keras-i…
dewball345 Aug 9, 2023
b0658fc
black again
dewball345 Aug 9, 2023
30ed364
Update barlow_twins.py
dewball345 Sep 4, 2023
3cd5c1f
Update barlow_twins.ipynb
dewball345 Sep 4, 2023
ea7b585
done
dewball345 Jul 16, 2024
27597be
done
dewball345 Jul 16, 2024
4435263
date
dewball345 Jul 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions examples/vision/barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
Title: Barlow Twins for Contrastive SSL
Author: [Abhiraam Eranti](https://github.com/dewball345)
Date created: 11/4/21
Last modified: 12/20/21
Last modified: 7/16/24
Description: A keras implementation of Barlow Twins (constrastive SSL with redundancy reduction).
Accelerator: GPU
"""

"""
## Introduction
"""

"""
Self-supervised learning (SSL) is a relatively novel technique in which a model
learns from unlabeled data, and is often used when the data is corrupted or
if there is very little of it. A practical use for SSL is to create
Expand Down Expand Up @@ -51,7 +50,7 @@
Also, it is simpler than other methods.

This notebook can train a Barlow Twins model and reach up to
64% validation accuracy on the CIFAR-10 dataset.
67% validation accuracy on the CIFAR-10 dataset.
"""

"""
Expand All @@ -61,13 +60,11 @@




"""

"""
### High-Level Theory


"""

"""
Expand Down Expand Up @@ -127,7 +124,6 @@
Original Implementation:
[facebookresearch/barlowtwins](https://github.com/facebookresearch/barlowtwins)


"""

"""
Expand All @@ -136,6 +132,7 @@

"""shell
pip install tensorflow-addons
pip install --upgrade-strategy=only-if-needed tensorflow_similarity
"""

import os
Expand All @@ -153,6 +150,7 @@
import numpy as np # np.random.random
import matplotlib.pyplot as plt # graphs
import datetime # tensorboard logs naming
import tensorflow_similarity # loss function module

# XLA optimization for faster performance(up to 10-15 minutes total time saved)
tf.config.optimizer.set_jit(True)
Expand Down Expand Up @@ -605,9 +603,15 @@ def plot_values(batch: tuple):

After this the two parts are summed together.

We will be using the
[BarlowLoss](https://github.com/tensorflow/similarity/blob/master/tensorflow_similarity/losses/barlow.py)
module from Tensorflow Similarity

A from-scratch implementation is also included below.
"""


"""
### From-Scratch implementation (for understanding purposes)
"""


Expand Down Expand Up @@ -767,18 +771,17 @@ class to make the cross corr matrix, then finds the loss and
"""

"""
Resnet encoder network implementation:
### Resnet encoder network implementation:
"""


class ResNet34:
"""Resnet34 class.

Responsible for the Resnet 34 architecture.
Responsible for the Resnet 34 architecture.
Modified from
https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
View their website for more information.
View their website for more information.
"""

def identity_block(self, x, filter):
Expand Down Expand Up @@ -846,7 +849,7 @@ def __call__(self, shape=(32, 32, 3)):


"""
Projector network:
### Projector network:
"""


Expand Down Expand Up @@ -960,7 +963,9 @@ def train_step(self, batch: tf.Tensor) -> tf.Tensor:
# chose the LAMB optimizer due to high batch sizes. Converged MUCH faster
# than ADAM or SGD
optimizer = tfa.optimizers.LAMB()
loss = BarlowLoss(BATCH_SIZE)

# We can just drop in either one of the two(results will be similar for both)
loss = tensorflow_similarity.losses.Barlow() # BarlowLoss(BATCH_SIZE)

bm.compile(optimizer=optimizer, loss=loss)

Expand Down Expand Up @@ -1023,12 +1028,12 @@ def train_step(self, batch: tf.Tensor) -> tf.Tensor:

* Barlow Twins is a simple and concise method for contrastive and self-supervised
learning.
* With this resnet-34 model architecture, we were able to reach 62-64% validation
* With this resnet-34 model architecture, we were able to reach 67% validation
accuracy.

## Use-Cases of Barlow-Twins(and contrastive learning in General)

* Semi-supervised learning: You can see that this model gave a 62-64% boost in accuracy
* Semi-supervised learning: You can see that this model gave a 67% boost in accuracy
when it wasn't even trained with the labels. It can be used when you have little labeled
data but a lot of unlabeled data.
* You do barlow twins training on the unlabeled data, and then you do secondary training
Expand All @@ -1045,5 +1050,4 @@ def train_step(self, batch: tf.Tensor) -> tf.Tensor:
* Thanks to Yashowardhan Shinde for writing the article.



"""
50 changes: 33 additions & 17 deletions examples/vision/ipynb/barlow_twins.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"Also, it is simpler than other methods.\n",
"\n",
"This notebook can train a Barlow Twins model and reach up to\n",
"64% validation accuracy on the CIFAR-10 dataset."
"67% validation accuracy on the CIFAR-10 dataset."
]
},
{
Expand All @@ -82,7 +82,6 @@
"\n",
"\n",
"\n",
"\n",
""
]
},
Expand All @@ -92,8 +91,7 @@
"colab_type": "text"
},
"source": [
"### High-Level Theory\n",
""
"### High-Level Theory"
]
},
{
Expand Down Expand Up @@ -171,8 +169,7 @@
"Reduction](https://arxiv.org/abs/2103.03230)\n",
"\n",
"Original Implementation:\n",
" [facebookresearch/barlowtwins](https://github.com/facebookresearch/barlowtwins)\n",
""
" [facebookresearch/barlowtwins](https://github.com/facebookresearch/barlowtwins)"
]
},
{
Expand All @@ -192,7 +189,8 @@
},
"outputs": [],
"source": [
"!pip install tensorflow-addons"
"!!pip install tensorflow-addons\n",
"!!pip install --upgrade-strategy=only-if-needed tensorflow_similarity"
]
},
{
Expand All @@ -218,6 +216,7 @@
"import numpy as np # np.random.random\n",
"import matplotlib.pyplot as plt # graphs\n",
"import datetime # tensorboard logs naming\n",
"import tensorflow_similarity # loss function module\n",
"\n",
"# XLA optimization for faster performance(up to 10-15 minutes total time saved)\n",
"tf.config.optimizer.set_jit(True)"
Expand Down Expand Up @@ -761,8 +760,20 @@
"\n",
"After this the two parts are summed together.\n",
"\n",
"We will be using the [BarlowLoss](https://github.com/tensorflow/similarity/blob/master/tensorflow_similarity/lo\n",
"sses/barlow.py)\n",
"module from Tensorflow Similarity\n",
"\n",
""
"A from-scratch implementation is also included below."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### From-Scratch implementation(for understanding purposes)"
]
},
{
Expand Down Expand Up @@ -941,7 +952,7 @@
"colab_type": "text"
},
"source": [
"Resnet encoder network implementation:"
"### Resnet encoder network implementation:"
]
},
{
Expand All @@ -956,11 +967,13 @@
"class ResNet34:\n",
" \"\"\"Resnet34 class.\n",
"\n",
" Responsible for the Resnet 34 architecture.\n",
" Modified from\n",
" Responsible for the Resnet 34 architecture.\n",
" Modified from\n",
" https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.\n",
" https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.\n",
" https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.\n",
" https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.\n",
" View their website for more information.\n",
" View their website for more information.\n",
" \"\"\"\n",
"\n",
" def identity_block(self, x, filter):\n",
Expand Down Expand Up @@ -1034,7 +1047,7 @@
"colab_type": "text"
},
"source": [
"Projector network:"
"### Projector network:"
]
},
{
Expand Down Expand Up @@ -1184,7 +1197,9 @@
"# chose the LAMB optimizer due to high batch sizes. Converged MUCH faster\n",
"# than ADAM or SGD\n",
"optimizer = tfa.optimizers.LAMB()\n",
"loss = BarlowLoss(BATCH_SIZE)\n",
"\n",
"# We can just drop in either one of the two(results will be similar for both)\n",
"loss = tensorflow_similarity.losses.Barlow() # BarlowLoss(BATCH_SIZE)\n",
"\n",
"bm.compile(optimizer=optimizer, loss=loss)\n",
"\n",
Expand Down Expand Up @@ -1267,12 +1282,12 @@
"\n",
"* Barlow Twins is a simple and concise method for contrastive and self-supervised\n",
"learning.\n",
"* With this resnet-34 model architecture, we were able to reach 62-64% validation\n",
"* With this resnet-34 model architecture, we were able to reach 67% validation\n",
"accuracy.\n",
"\n",
"## Use-Cases of Barlow-Twins(and contrastive learning in General)\n",
"\n",
"* Semi-supervised learning: You can see that this model gave a 62-64% boost in accuracy\n",
"* Semi-supervised learning: You can see that this model gave a 67% boost in accuracy\n",
"when it wasn't even trained with the labels. It can be used when you have little labeled\n",
"data but a lot of unlabeled data.\n",
"* You do barlow twins training on the unlabeled data, and then you do secondary training\n",
Expand All @@ -1284,12 +1299,13 @@
"* [Original Pytorch Implementation](https://github.com/facebookresearch/barlowtwins)\n",
"* [Sayak Paul's\n",
"Implementation](https://colab.research.google.com/github/sayakpaul/Barlow-Twins-TF/blob/main/Barlow_Twins.ipynb#scrollTo=GlWepkM8_prl).\n",
"Implementation](https://colab.research.google.com/github/sayakpaul/Barlow-Twins-TF/blob/main/Barlow_Twins.ipynb#scrollTo=GlWepkM8_prl).\n",
"* Thanks to Sayak Paul for his implementation. It helped me with debugging and\n",
"comparisons of accuracy, loss.\n",
"* [resnet34\n",
"implementation](https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2)\n",
"implementation](https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2)\n",
" * Thanks to Yashowardhan Shinde for writing the article.\n",
"\n",
""
]
}
Expand Down
Loading