In [14]:
from simclr_data import load_unlabeled_dataset, generate_pair
from simclr_model import build_simclr_model
from simclr_loss import nt_xent_loss
import tensorflow as tf
import os


In [None]:


BATCH_SIZE = 32
unlabel_path = os.path.expanduser("~/Desktop/Dataset xray/Dataset unlabelled")
data_gen = load_unlabeled_dataset(unlabel_path, batch_size=BATCH_SIZE)

simclr_model = build_simclr_model()
optimizer = tf.keras.optimizers.Adam(1e-3)

def train_step(imgs):
    print("Raw batch shape from generator:", imgs.shape)
    x1, x2 = generate_pair(imgs)
    print("Augmented shapes:", x1.shape, x2.shape)
    with tf.GradientTape() as tape:
        z1 = simclr_model(x1, training=True)
        z2 = simclr_model(x2, training=True)
        loss = nt_xent_loss(z1, z2)
    grads = tape.gradient(loss, simclr_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, simclr_model.trainable_variables))
    return loss

for epoch in range(20):
    print(f"Epoch {epoch+1}/20")
    for batch in data_gen:
        if batch.shape[0] == 0:
            continue
        loss = train_step(batch)
        print(f"Loss: {loss:.4f}")
simclr_model.save("simclr_encoder.h5")



Found 9243 images belonging to 1 classes.
Epoch 1/20
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 4.0893
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.9754
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.9975
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.9867
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.8092
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.7328
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.5736
Raw batch shape from generator: (32, 224, 224, 3)
Augmented shapes: (32, 224, 224, 3) (32, 224, 224, 3)
Loss: 3.6677
Raw batch s

In [None]:
git config --global user.name "drishtiseth"
git config --global user.email "drishtiseth@gmail.com"
